This commit is contained in:
khanon
2023-12-04 04:21:18 +00:00
parent cd1b9d0e0c
commit fbdea30264
31 changed files with 1237 additions and 216 deletions
+10 -32
View File
@@ -26,46 +26,23 @@ type AnthropicAPIError = {
type UpdateFn = typeof AnthropicKeyProvider.prototype.update;
export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
private readonly updateKey: UpdateFn;
constructor(keys: AnthropicKey[], updateKey: UpdateFn) {
super(keys, {
service: "anthropic",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
updateKey,
});
this.updateKey = updateKey;
}
protected async checkKey(key: AnthropicKey) {
if (key.isDisabled) {
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
this.scheduleNextCheck();
return;
}
this.log.debug({ key: key.hash }, "Checking key...");
let isInitialCheck = !key.lastChecked;
try {
const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
const updates = { isPozzed: pozzed };
this.updateKey(key.hash, updates);
this.log.info(
{ key: key.hash, models: key.modelFamilies },
"Key check complete."
);
} catch (error) {
// touch the key so we don't check it again for a while
this.updateKey(key.hash, {});
this.handleAxiosError(key, error as AxiosError);
}
this.lastCheck = Date.now();
// Only enqueue the next check if this wasn't a startup check, since those
// are batched together elsewhere.
if (!isInitialCheck) {
this.scheduleNextCheck();
}
protected async testKeyOrFail(key: AnthropicKey) {
const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
const updates = { isPozzed: pozzed };
this.updateKey(key.hash, updates);
this.log.info(
{ key: key.hash, models: key.modelFamilies },
"Checked key."
);
}
protected handleAxiosError(key: AnthropicKey, error: AxiosError) {
@@ -84,6 +61,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
{ key: key.hash, error: error.message },
"Key is rate limited. Rechecking in 10 seconds."
);
0;
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
this.updateKey(key.hash, { lastChecked: next });
break;
+19 -41
View File
@@ -32,58 +32,36 @@ type GetLoggingConfigResponse = {
type UpdateFn = typeof AwsBedrockKeyProvider.prototype.update;
export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
private readonly updateKey: UpdateFn;
constructor(keys: AwsBedrockKey[], updateKey: UpdateFn) {
super(keys, {
service: "aws",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
updateKey,
});
this.updateKey = updateKey;
}
protected async checkKey(key: AwsBedrockKey) {
if (key.isDisabled) {
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
this.scheduleNextCheck();
return;
protected async testKeyOrFail(key: AwsBedrockKey) {
// Only check models on startup. For now all models must be available to
// the proxy because we don't route requests to different keys.
const modelChecks: Promise<unknown>[] = [];
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
}
this.log.debug({ key: key.hash }, "Checking key...");
let isInitialCheck = !key.lastChecked;
try {
// Only check models on startup. For now all models must be available to
// the proxy because we don't route requests to different keys.
const modelChecks: Promise<unknown>[] = [];
if (isInitialCheck) {
modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
}
await Promise.all(modelChecks);
await this.checkLoggingConfiguration(key);
await Promise.all(modelChecks);
await this.checkLoggingConfiguration(key);
this.log.info(
{
key: key.hash,
models: key.modelFamilies,
logged: key.awsLoggingStatus,
},
"Key check complete."
);
} catch (error) {
this.handleAxiosError(key, error as AxiosError);
}
this.updateKey(key.hash, {});
this.lastCheck = Date.now();
// Only enqueue the next check if this wasn't a startup check, since those
// are batched together elsewhere.
if (!isInitialCheck) {
this.scheduleNextCheck();
}
this.log.info(
{
key: key.hash,
models: key.modelFamilies,
logged: key.awsLoggingStatus,
},
"Checked key."
);
}
protected handleAxiosError(key: AwsBedrockKey, error: AxiosError) {
+149
View File
@@ -0,0 +1,149 @@
import axios, { AxiosError } from "axios";
import { KeyCheckerBase } from "../key-checker-base";
import type { AzureOpenAIKey, AzureOpenAIKeyProvider } from "./provider";
import { getAzureOpenAIModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes
const AZURE_HOST = process.env.AZURE_HOST || "%RESOURCE_NAME%.openai.azure.com";
const POST_CHAT_COMPLETIONS = (resourceName: string, deploymentId: string) =>
`https://${AZURE_HOST.replace(
"%RESOURCE_NAME%",
resourceName
)}/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`;
type AzureError = {
error: {
message: string;
type: string | null;
param: string;
code: string;
status: number;
};
};
type UpdateFn = typeof AzureOpenAIKeyProvider.prototype.update;
export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
constructor(keys: AzureOpenAIKey[], updateKey: UpdateFn) {
super(keys, {
service: "azure",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
}
protected async testKeyOrFail(key: AzureOpenAIKey) {
const model = await this.testModel(key);
this.log.info(
{ key: key.hash, deploymentModel: model },
"Checked key."
);
this.updateKey(key.hash, { modelFamilies: [model] });
}
// provided api-key header isn't valid (401)
// {
// "error": {
// "code": "401",
// "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
// }
// }
// api key correct but deployment id is wrong (404)
// {
// "error": {
// "code": "DeploymentNotFound",
// "message": "The API deployment for this resource does not exist. If you created the deployment within the last 5 minutes, please wait a moment and try again."
// }
// }
// resource name is wrong (node will throw ENOTFOUND)
// rate limited (429)
// TODO: try to reproduce this
protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) {
if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) {
const data = error.response.data;
const status = data.error.status;
const errorType = data.error.code || data.error.type;
switch (errorType) {
case "DeploymentNotFound":
this.log.warn(
{ key: key.hash, errorType, error: error.response.data },
"Key is revoked or deployment ID is incorrect. Disabling key."
);
return this.updateKey(key.hash, {
isDisabled: true,
isRevoked: true,
});
case "401":
this.log.warn(
{ key: key.hash, errorType, error: error.response.data },
"Key is disabled or incorrect. Disabling key."
);
return this.updateKey(key.hash, {
isDisabled: true,
isRevoked: true,
});
default:
this.log.error(
{ key: key.hash, errorType, error: error.response.data, status },
"Unknown Azure API error while checking key. Please report this."
);
return this.updateKey(key.hash, { lastChecked: Date.now() });
}
}
const { response, code } = error;
if (code === "ENOTFOUND") {
this.log.warn(
{ key: key.hash, error: error.message },
"Resource name is probably incorrect. Disabling key."
);
return this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
}
const { headers, status, data } = response ?? {};
this.log.error(
{ key: key.hash, status, headers, data, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}
private async testModel(key: AzureOpenAIKey) {
const { apiKey, deploymentId, resourceName } =
AzureOpenAIKeyChecker.getCredentialsFromKey(key);
const url = POST_CHAT_COMPLETIONS(resourceName, deploymentId);
const testRequest = {
max_tokens: 1,
stream: false,
messages: [{ role: "user", content: "" }],
};
const { data } = await axios.post(url, testRequest, {
headers: { "Content-Type": "application/json", "api-key": apiKey },
});
return getAzureOpenAIModelFamily(data.model);
}
static errorIsAzureError(error: AxiosError): error is AxiosError<AzureError> {
const data = error.response?.data as any;
return data?.error?.code || data?.error?.type;
}
static getCredentialsFromKey(key: AzureOpenAIKey) {
const [resourceName, deploymentId, apiKey] = key.key.split(":");
if (!resourceName || !deploymentId || !apiKey) {
throw new Error(
"Invalid Azure credential format. Refer to .env.example and ensure your credentials are in the format RESOURCE_NAME:DEPLOYMENT_ID:API_KEY with commas between each credential set."
);
}
return { resourceName, deploymentId, apiKey };
}
}
+212
View File
@@ -0,0 +1,212 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models";
import { OpenAIModel } from "../openai/provider";
import { AzureOpenAIKeyChecker } from "./checker";
import { AwsKeyChecker } from "../aws/checker";
export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;
type AzureOpenAIKeyUsage = {
[K in AzureOpenAIModelFamily as `${K}Tokens`]: number;
};
export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
readonly service: "azure";
readonly modelFamilies: AzureOpenAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
contentFiltering: boolean;
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 4000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 250;
export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
readonly service = "azure";
private keys: AzureOpenAIKey[] = [];
private checker?: AzureOpenAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.azureCredentials;
if (!keyConfig) {
this.log.warn(
"AZURE_CREDENTIALS is not set. Azure OpenAI API will not be available."
);
return;
}
let bareKeys: string[];
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: AzureOpenAIKey = {
key,
service: this.service,
modelFamilies: ["azure-gpt4"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
contentFiltering: false,
hash: `azu-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
"azure-turboTokens": 0,
"azure-gpt4Tokens": 0,
"azure-gpt4-32kTokens": 0,
"azure-gpt4-turboTokens": 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded Azure OpenAI keys.");
}
public init() {
if (config.checkKeys) {
this.checker = new AzureOpenAIKeyChecker(
this.keys,
this.update.bind(this)
);
this.checker.start();
}
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(model: AzureOpenAIModel) {
const neededFamily = getAzureOpenAIModelFamily(model);
const availableKeys = this.keys.filter(
(k) => !k.isDisabled && k.modelFamilies.includes(neededFamily)
);
if (availableKeys.length === 0) {
throw new Error(`No keys available for model family '${neededFamily}'.`);
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 3. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
public disable(key: AzureOpenAIKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<AzureOpenAIKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
}
// TODO: all of this shit is duplicate code
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return time until the first key is ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
/**
* This is called when we receive a 429, which means there are already five
* concurrent requests running on this key. We don't have any information on
* when these requests will resolve, so all we can do is wait a bit and try
* again. We will lock the key for 2 seconds after getting a 429 before
* retrying in order to give the other requests a chance to finish.
*/
public markRateLimited(keyHash: string) {
this.log.debug({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
const now = Date.now();
key.rateLimitedAt = now;
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
}
public recheck() {
this.keys.forEach(({ hash }) =>
this.update(hash, { lastChecked: 0, isDisabled: false })
);
}
/**
* Applies a short artificial delay to the key upon dequeueing, in order to
* prevent it from being immediately assigned to another request before the
* current one can be dispatched.
**/
private throttle(hash: string) {
const now = Date.now();
const key = this.keys.find((k) => k.hash === hash)!;
const currentRateLimit = key.rateLimitedUntil;
const nextRateLimit = now + KEY_REUSE_DELAY;
key.rateLimitedAt = now;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
}
+10 -2
View File
@@ -2,6 +2,7 @@ import { OpenAIModel } from "./openai/provider";
import { AnthropicModel } from "./anthropic/provider";
import { GooglePalmModel } from "./palm/provider";
import { AwsBedrockModel } from "./aws/provider";
import { AzureOpenAIModel } from "./azure/provider";
import { KeyPool } from "./key-pool";
import type { ModelFamily } from "../models";
@@ -13,12 +14,18 @@ export type APIFormat =
| "openai-text"
| "openai-image";
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
export type LLMService =
| "openai"
| "anthropic"
| "google-palm"
| "aws"
| "azure";
export type Model =
| OpenAIModel
| AnthropicModel
| GooglePalmModel
| AwsBedrockModel;
| AwsBedrockModel
| AzureOpenAIModel;
export interface Key {
/** The API key itself. Never log this, use `hash` instead. */
@@ -72,3 +79,4 @@ export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GooglePalmKey } from "./palm/provider";
export { AwsBedrockKey } from "./aws/provider";
export { AzureOpenAIKey } from "./azure/provider";
+57 -20
View File
@@ -3,14 +3,17 @@ import { logger } from "../../logger";
import { Key } from "./index";
import { AxiosError } from "axios";
type KeyCheckerOptions = {
type KeyCheckerOptions<TKey extends Key = Key> = {
service: string;
keyCheckPeriod: number;
minCheckInterval: number;
}
recurringChecksEnabled?: boolean;
updateKey: (hash: string, props: Partial<TKey>) => void;
};
export abstract class KeyCheckerBase<TKey extends Key> {
protected readonly service: string;
protected readonly RECURRING_CHECKS_ENABLED: boolean;
/** Minimum time in between any two key checks. */
protected readonly MIN_CHECK_INTERVAL: number;
/**
@@ -19,16 +22,19 @@ export abstract class KeyCheckerBase<TKey extends Key> {
* than this.
*/
protected readonly KEY_CHECK_PERIOD: number;
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
protected readonly keys: TKey[] = [];
protected log: pino.Logger;
protected timeout?: NodeJS.Timeout;
protected lastCheck = 0;
protected constructor(keys: TKey[], opts: KeyCheckerOptions) {
protected constructor(keys: TKey[], opts: KeyCheckerOptions<TKey>) {
const { service, keyCheckPeriod, minCheckInterval } = opts;
this.keys = keys;
this.KEY_CHECK_PERIOD = keyCheckPeriod;
this.MIN_CHECK_INTERVAL = minCheckInterval;
this.RECURRING_CHECKS_ENABLED = opts.recurringChecksEnabled ?? true;
this.updateKey = opts.updateKey;
this.service = service;
this.log = logger.child({ module: "key-checker", service });
}
@@ -52,31 +58,34 @@ export abstract class KeyCheckerBase<TKey extends Key> {
* the minimum check interval.
*/
public scheduleNextCheck() {
// Gives each concurrent check a correlation ID to make logs less confusing.
const callId = Math.random().toString(36).slice(2, 8);
const timeoutId = this.timeout?.[Symbol.toPrimitive]?.();
const checkLog = this.log.child({ callId, timeoutId });
const enabledKeys = this.keys.filter((key) => !key.isDisabled);
checkLog.debug({ enabled: enabledKeys.length }, "Scheduling next check...");
const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
const numEnabled = enabledKeys.length;
const numUnchecked = uncheckedKeys.length;
clearTimeout(this.timeout);
this.timeout = undefined;
if (enabledKeys.length === 0) {
checkLog.warn("All keys are disabled. Key checker stopping.");
if (!numEnabled) {
checkLog.warn("All keys are disabled. Stopping.");
return;
}
// Perform startup checks for any keys that haven't been checked yet.
const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
checkLog.debug({ unchecked: uncheckedKeys.length }, "# of unchecked keys");
if (uncheckedKeys.length > 0) {
const keysToCheck = uncheckedKeys.slice(0, 12);
checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
if (numUnchecked > 0) {
const keycheckBatch = uncheckedKeys.slice(0, 12);
this.timeout = setTimeout(async () => {
try {
await Promise.all(keysToCheck.map((key) => this.checkKey(key)));
await Promise.all(keycheckBatch.map((key) => this.checkKey(key)));
} catch (error) {
this.log.error({ error }, "Error checking one or more keys.");
checkLog.error({ error }, "Error checking one or more keys.");
}
checkLog.info("Batch complete.");
this.scheduleNextCheck();
@@ -84,11 +93,18 @@ export abstract class KeyCheckerBase<TKey extends Key> {
checkLog.info(
{
batch: keysToCheck.map((k) => k.hash),
remaining: uncheckedKeys.length - keysToCheck.length,
batch: keycheckBatch.map((k) => k.hash),
remaining: uncheckedKeys.length - keycheckBatch.length,
newTimeoutId: this.timeout?.[Symbol.toPrimitive]?.(),
},
"Scheduled batch check."
"Scheduled batch of initial checks."
);
return;
}
if (!this.RECURRING_CHECKS_ENABLED) {
checkLog.info(
"Initial checks complete and recurring checks are disabled for this service. Stopping."
);
return;
}
@@ -106,14 +122,35 @@ export abstract class KeyCheckerBase<TKey extends Key> {
);
const delay = nextCheck - Date.now();
this.timeout = setTimeout(() => this.checkKey(oldestKey), delay);
this.timeout = setTimeout(
() => this.checkKey(oldestKey).then(() => this.scheduleNextCheck()),
delay
);
checkLog.debug(
{ key: oldestKey.hash, nextCheck: new Date(nextCheck), delay },
"Scheduled single key check."
"Scheduled next recurring check."
);
}
protected abstract checkKey(key: TKey): Promise<void>;
public async checkKey(key: TKey): Promise<void> {
if (key.isDisabled) {
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
this.scheduleNextCheck();
return;
}
this.log.debug({ key: key.hash }, "Checking key...");
try {
await this.testKeyOrFail(key);
} catch (error) {
this.updateKey(key.hash, {});
this.handleAxiosError(key, error as AxiosError);
}
this.lastCheck = Date.now();
}
protected abstract testKeyOrFail(key: TKey): Promise<void>;
protected abstract handleAxiosError(key: TKey, error: AxiosError): void;
}
}
+9
View File
@@ -11,6 +11,7 @@ import { GooglePalmKeyProvider } from "./palm/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { ModelFamily } from "../models";
import { assertNever } from "../utils";
import { AzureOpenAIKeyProvider } from "./azure/provider";
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
@@ -25,6 +26,7 @@ export class KeyPool {
this.keyProviders.push(new AnthropicKeyProvider());
this.keyProviders.push(new GooglePalmKeyProvider());
this.keyProviders.push(new AwsBedrockKeyProvider());
this.keyProviders.push(new AzureOpenAIKeyProvider());
}
public init() {
@@ -124,6 +126,8 @@ export class KeyPool {
// AWS offers models from a few providers
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
return "aws";
} else if (model.startsWith("azure")) {
return "azure";
}
throw new Error(`Unknown service for model '${model}'`);
}
@@ -142,6 +146,11 @@ export class KeyPool {
return "google-palm";
case "aws-claude":
return "aws";
case "azure-turbo":
case "azure-gpt4":
case "azure-gpt4-32k":
case "azure-gpt4-turbo":
return "azure";
default:
assertNever(modelFamily);
}
+34 -47
View File
@@ -27,65 +27,41 @@ type UpdateFn = typeof OpenAIKeyProvider.prototype.update;
export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
private readonly cloneKey: CloneFn;
private readonly updateKey: UpdateFn;
constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) {
super(keys, {
service: "openai",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
this.cloneKey = cloneFn;
this.updateKey = updateKey;
}
protected async checkKey(key: OpenAIKey) {
if (key.isDisabled) {
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
this.scheduleNextCheck();
return;
}
this.log.debug({ key: key.hash }, "Checking key...");
let isInitialCheck = !key.lastChecked;
try {
// We only need to check for provisioned models on the initial check.
if (isInitialCheck) {
const [provisionedModels, livenessTest] = await Promise.all([
this.getProvisionedModels(key),
this.testLiveness(key),
this.maybeCreateOrganizationClones(key),
]);
const updates = {
modelFamilies: provisionedModels,
isTrial: livenessTest.rateLimit <= 250,
};
this.updateKey(key.hash, updates);
} else {
// No updates needed as models and trial status generally don't change.
const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
this.updateKey(key.hash, {});
}
this.log.info(
{ key: key.hash, models: key.modelFamilies, trial: key.isTrial },
"Key check complete."
);
} catch (error) {
// touch the key so we don't check it again for a while
protected async testKeyOrFail(key: OpenAIKey) {
// We only need to check for provisioned models on the initial check.
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
const [provisionedModels, livenessTest] = await Promise.all([
this.getProvisionedModels(key),
this.testLiveness(key),
this.maybeCreateOrganizationClones(key),
]);
const updates = {
modelFamilies: provisionedModels,
isTrial: livenessTest.rateLimit <= 250,
};
this.updateKey(key.hash, updates);
} else {
// No updates needed as models and trial status generally don't change.
const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
this.updateKey(key.hash, {});
this.handleAxiosError(key, error as AxiosError);
}
this.lastCheck = Date.now();
// Only enqueue the next check if this wasn't a startup check, since those
// are batched together elsewhere.
if (!isInitialCheck) {
this.log.info(
{ key: key.hash },
"Recurring keychecks are disabled, no-op."
);
// this.scheduleNextCheck();
}
this.log.info(
{ key: key.hash, models: key.modelFamilies, trial: key.isTrial },
"Checked key."
);
}
private async getProvisionedModels(
@@ -138,6 +114,17 @@ export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
.filter(({ is_default }) => !is_default)
.map(({ id }) => id);
this.cloneKey(key.hash, ids);
// It's possible that the keychecker may be stopped if all non-cloned keys
// happened to be unusable, in which case this clnoe will never be checked
// unless we restart the keychecker.
if (!this.timeout) {
this.log.warn(
{ parent: key.hash },
"Restarting key checker to check cloned keys."
);
this.scheduleNextCheck();
}
}
protected handleAxiosError(key: OpenAIKey, error: AxiosError) {
@@ -217,17 +217,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
return a.lastUsed - b.lastUsed;
});
// logger.debug(
// {
// byPriority: keysByPriority.map((k) => ({
// hash: k.hash,
// isRateLimited: now - k.rateLimitedAt < rateLimitThreshold,
// modelFamilies: k.modelFamilies,
// })),
// },
// "Keys sorted by priority"
// );
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
+34 -2
View File
@@ -2,15 +2,25 @@
import pino from "pino";
export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo" | "dall-e";
export type OpenAIModelFamily =
| "turbo"
| "gpt4"
| "gpt4-32k"
| "gpt4-turbo"
| "dall-e";
export type AnthropicModelFamily = "claude";
export type GooglePalmModelFamily = "bison";
export type AwsBedrockModelFamily = "aws-claude";
export type AzureOpenAIModelFamily = `azure-${Exclude<
OpenAIModelFamily,
"dall-e"
>}`;
export type ModelFamily =
| OpenAIModelFamily
| AnthropicModelFamily
| GooglePalmModelFamily
| AwsBedrockModelFamily;
| AwsBedrockModelFamily
| AzureOpenAIModelFamily;
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
@@ -23,6 +33,10 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"claude",
"bison",
"aws-claude",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-32k",
"azure-gpt4-turbo",
] as const);
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
@@ -64,6 +78,24 @@ export function getAwsBedrockModelFamily(_model: string): ModelFamily {
return "aws-claude";
}
export function getAzureOpenAIModelFamily(
model: string,
defaultFamily: AzureOpenAIModelFamily = "azure-gpt4"
): AzureOpenAIModelFamily {
// Azure model names omit periods. addAzureKey also prepends "azure-" to the
// model name to route the request the correct keyprovider, so we need to
// remove that as well.
const modified = model
.replace("gpt-35-turbo", "gpt-3.5-turbo")
.replace("azure-", "");
for (const [regex, family] of Object.entries(OPENAI_MODEL_FAMILY_MAP)) {
if (modified.match(regex)) {
return `azure-${family}` as AzureOpenAIModelFamily;
}
}
return defaultFamily;
}
export function assertIsKnownModelFamily(
modelFamily: string
): asserts modelFamily is ModelFamily {
+8
View File
@@ -12,6 +12,7 @@ import schedule from "node-schedule";
import { v4 as uuid } from "uuid";
import { config, getFirebaseApp } from "../../config";
import {
getAzureOpenAIModelFamily,
getClaudeModelFamily,
getGooglePalmModelFamily,
getOpenAIModelFamily,
@@ -34,6 +35,10 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
claude: 0,
bison: 0,
"aws-claude": 0,
"azure-turbo": 0,
"azure-gpt4": 0,
"azure-gpt4-turbo": 0,
"azure-gpt4-32k": 0,
};
const users: Map<string, User> = new Map();
@@ -382,6 +387,9 @@ function getModelFamilyForQuotaUsage(
model: string,
api: APIFormat
): ModelFamily {
// TODO: this seems incorrect
if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
switch (api) {
case "openai":
case "openai-text":