Azure OpenAI support (khanon/oai-reverse-proxy!48)
This commit is contained in:
@@ -26,46 +26,23 @@ type AnthropicAPIError = {
|
||||
type UpdateFn = typeof AnthropicKeyProvider.prototype.update;
|
||||
|
||||
export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
|
||||
private readonly updateKey: UpdateFn;
|
||||
|
||||
constructor(keys: AnthropicKey[], updateKey: UpdateFn) {
|
||||
super(keys, {
|
||||
service: "anthropic",
|
||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||
updateKey,
|
||||
});
|
||||
this.updateKey = updateKey;
|
||||
}
|
||||
|
||||
protected async checkKey(key: AnthropicKey) {
|
||||
if (key.isDisabled) {
|
||||
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
|
||||
this.scheduleNextCheck();
|
||||
return;
|
||||
}
|
||||
|
||||
this.log.debug({ key: key.hash }, "Checking key...");
|
||||
let isInitialCheck = !key.lastChecked;
|
||||
try {
|
||||
const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
|
||||
const updates = { isPozzed: pozzed };
|
||||
this.updateKey(key.hash, updates);
|
||||
this.log.info(
|
||||
{ key: key.hash, models: key.modelFamilies },
|
||||
"Key check complete."
|
||||
);
|
||||
} catch (error) {
|
||||
// touch the key so we don't check it again for a while
|
||||
this.updateKey(key.hash, {});
|
||||
this.handleAxiosError(key, error as AxiosError);
|
||||
}
|
||||
|
||||
this.lastCheck = Date.now();
|
||||
// Only enqueue the next check if this wasn't a startup check, since those
|
||||
// are batched together elsewhere.
|
||||
if (!isInitialCheck) {
|
||||
this.scheduleNextCheck();
|
||||
}
|
||||
protected async testKeyOrFail(key: AnthropicKey) {
|
||||
const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
|
||||
const updates = { isPozzed: pozzed };
|
||||
this.updateKey(key.hash, updates);
|
||||
this.log.info(
|
||||
{ key: key.hash, models: key.modelFamilies },
|
||||
"Checked key."
|
||||
);
|
||||
}
|
||||
|
||||
protected handleAxiosError(key: AnthropicKey, error: AxiosError) {
|
||||
@@ -84,6 +61,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
|
||||
{ key: key.hash, error: error.message },
|
||||
"Key is rate limited. Rechecking in 10 seconds."
|
||||
);
|
||||
0;
|
||||
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
|
||||
this.updateKey(key.hash, { lastChecked: next });
|
||||
break;
|
||||
|
||||
@@ -32,58 +32,36 @@ type GetLoggingConfigResponse = {
|
||||
type UpdateFn = typeof AwsBedrockKeyProvider.prototype.update;
|
||||
|
||||
export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
||||
private readonly updateKey: UpdateFn;
|
||||
|
||||
constructor(keys: AwsBedrockKey[], updateKey: UpdateFn) {
|
||||
super(keys, {
|
||||
service: "aws",
|
||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||
updateKey,
|
||||
});
|
||||
this.updateKey = updateKey;
|
||||
}
|
||||
|
||||
protected async checkKey(key: AwsBedrockKey) {
|
||||
if (key.isDisabled) {
|
||||
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
|
||||
this.scheduleNextCheck();
|
||||
return;
|
||||
protected async testKeyOrFail(key: AwsBedrockKey) {
|
||||
// Only check models on startup. For now all models must be available to
|
||||
// the proxy because we don't route requests to different keys.
|
||||
const modelChecks: Promise<unknown>[] = [];
|
||||
const isInitialCheck = !key.lastChecked;
|
||||
if (isInitialCheck) {
|
||||
modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
|
||||
modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
|
||||
}
|
||||
|
||||
this.log.debug({ key: key.hash }, "Checking key...");
|
||||
let isInitialCheck = !key.lastChecked;
|
||||
try {
|
||||
// Only check models on startup. For now all models must be available to
|
||||
// the proxy because we don't route requests to different keys.
|
||||
const modelChecks: Promise<unknown>[] = [];
|
||||
if (isInitialCheck) {
|
||||
modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
|
||||
modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
|
||||
}
|
||||
await Promise.all(modelChecks);
|
||||
await this.checkLoggingConfiguration(key);
|
||||
|
||||
await Promise.all(modelChecks);
|
||||
await this.checkLoggingConfiguration(key);
|
||||
|
||||
this.log.info(
|
||||
{
|
||||
key: key.hash,
|
||||
models: key.modelFamilies,
|
||||
logged: key.awsLoggingStatus,
|
||||
},
|
||||
"Key check complete."
|
||||
);
|
||||
} catch (error) {
|
||||
this.handleAxiosError(key, error as AxiosError);
|
||||
}
|
||||
|
||||
this.updateKey(key.hash, {});
|
||||
|
||||
this.lastCheck = Date.now();
|
||||
// Only enqueue the next check if this wasn't a startup check, since those
|
||||
// are batched together elsewhere.
|
||||
if (!isInitialCheck) {
|
||||
this.scheduleNextCheck();
|
||||
}
|
||||
this.log.info(
|
||||
{
|
||||
key: key.hash,
|
||||
models: key.modelFamilies,
|
||||
logged: key.awsLoggingStatus,
|
||||
},
|
||||
"Checked key."
|
||||
);
|
||||
}
|
||||
|
||||
protected handleAxiosError(key: AwsBedrockKey, error: AxiosError) {
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
import axios, { AxiosError } from "axios";
|
||||
import { KeyCheckerBase } from "../key-checker-base";
|
||||
import type { AzureOpenAIKey, AzureOpenAIKeyProvider } from "./provider";
|
||||
import { getAzureOpenAIModelFamily } from "../../models";
|
||||
|
||||
/** Minimum spacing between any two key checks, in milliseconds (3 seconds). */
const MIN_CHECK_INTERVAL = 3 * 1000;
/** How often a single key is re-checked, in milliseconds (3 minutes). */
const KEY_CHECK_PERIOD = 3 * 60 * 1000;
/**
 * Host template for Azure OpenAI requests. The `%RESOURCE_NAME%` placeholder
 * is substituted per key; the `AZURE_HOST` environment variable overrides the
 * default template when it is set to a non-empty value.
 */
const AZURE_HOST = process.env.AZURE_HOST || "%RESOURCE_NAME%.openai.azure.com";
/**
 * Builds the chat completions endpoint URL for the given Azure resource and
 * deployment.
 */
const POST_CHAT_COMPLETIONS = (resourceName: string, deploymentId: string) => {
  const host = AZURE_HOST.replace("%RESOURCE_NAME%", resourceName);
  return `https://${host}/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`;
};
|
||||
|
||||
type AzureError = {
|
||||
error: {
|
||||
message: string;
|
||||
type: string | null;
|
||||
param: string;
|
||||
code: string;
|
||||
status: number;
|
||||
};
|
||||
};
|
||||
type UpdateFn = typeof AzureOpenAIKeyProvider.prototype.update;
|
||||
|
||||
/**
 * Key checker for Azure OpenAI credentials.
 *
 * Each credential is tested by sending a minimal chat completion request to
 * its deployment. The model reported in the response determines the key's
 * model family. Recurring checks are disabled for this service, so keys are
 * only tested during the initial (startup) batch.
 */
export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
  /**
   * @param keys Keys to check; the array is handed to the base checker.
   * @param updateKey Callback used to write check results back to the key
   *   provider.
   */
  constructor(keys: AzureOpenAIKey[], updateKey: UpdateFn) {
    super(keys, {
      service: "azure",
      keyCheckPeriod: KEY_CHECK_PERIOD,
      minCheckInterval: MIN_CHECK_INTERVAL,
      recurringChecksEnabled: false,
      updateKey,
    });
  }

  /**
   * Tests a single key and records the model family of its deployment.
   * Any error thrown by the test request propagates to the caller.
   */
  protected async testKeyOrFail(key: AzureOpenAIKey) {
    const model = await this.testModel(key);
    this.log.info(
      { key: key.hash, deploymentModel: model },
      "Checked key."
    );
    this.updateKey(key.hash, { modelFamilies: [model] });
  }

  // provided api-key header isn't valid (401)
  // {
  //   "error": {
  //     "code": "401",
  //     "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
  //   }
  // }

  // api key correct but deployment id is wrong (404)
  // {
  //   "error": {
  //     "code": "DeploymentNotFound",
  //     "message": "The API deployment for this resource does not exist. If you created the deployment within the last 5 minutes, please wait a moment and try again."
  //   }
  // }

  // resource name is wrong (node will throw ENOTFOUND)

  // rate limited (429)
  // TODO: try to reproduce this

  /**
   * Translates a failed test request into a key update.
   *
   * - Azure error `DeploymentNotFound` or `401`: the key is disabled and
   *   marked revoked.
   * - Any other Azure error payload: logged, and the key is only touched.
   * - `ENOTFOUND` (DNS failure): the key is disabled and marked revoked.
   * - Anything else is treated as a network error and `lastChecked` is
   *   back-dated so the key becomes due again one minute from now.
   */
  protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) {
    if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) {
      const data = error.response.data;
      // The sample error bodies above carry no `status` field, so fall back
      // to the HTTP status to keep the log entry informative.
      const status = data.error.status ?? error.response.status;
      const errorType = data.error.code || data.error.type;
      switch (errorType) {
        case "DeploymentNotFound":
          this.log.warn(
            { key: key.hash, errorType, error: error.response.data },
            "Key is revoked or deployment ID is incorrect. Disabling key."
          );
          return this.updateKey(key.hash, {
            isDisabled: true,
            isRevoked: true,
          });
        case "401":
          this.log.warn(
            { key: key.hash, errorType, error: error.response.data },
            "Key is disabled or incorrect. Disabling key."
          );
          return this.updateKey(key.hash, {
            isDisabled: true,
            isRevoked: true,
          });
        default:
          this.log.error(
            { key: key.hash, errorType, error: error.response.data, status },
            "Unknown Azure API error while checking key. Please report this."
          );
          return this.updateKey(key.hash, { lastChecked: Date.now() });
      }
    }

    const { response, code } = error;
    if (code === "ENOTFOUND") {
      this.log.warn(
        { key: key.hash, error: error.message },
        "Resource name is probably incorrect. Disabling key."
      );
      return this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
    }

    const { headers, status, data } = response ?? {};
    this.log.error(
      { key: key.hash, status, headers, data, error: error.message },
      "Network error while checking key; trying this key again in a minute."
    );
    // Back-date lastChecked so that KEY_CHECK_PERIOD elapses in one minute.
    // NOTE(review): this checker is constructed with
    // recurringChecksEnabled: false, so the retry presumably only happens if
    // something else reschedules checks — confirm.
    const oneMinute = 60 * 1000;
    const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
    this.updateKey(key.hash, { lastChecked: next });
  }

  /**
   * Sends a one-token chat completion to the key's deployment and returns the
   * Azure model family derived from the model name in the response.
   */
  private async testModel(key: AzureOpenAIKey) {
    const { apiKey, deploymentId, resourceName } =
      AzureOpenAIKeyChecker.getCredentialsFromKey(key);
    const url = POST_CHAT_COMPLETIONS(resourceName, deploymentId);
    const testRequest = {
      max_tokens: 1,
      stream: false,
      messages: [{ role: "user", content: "" }],
    };
    const { data } = await axios.post(url, testRequest, {
      headers: { "Content-Type": "application/json", "api-key": apiKey },
    });

    return getAzureOpenAIModelFamily(data.model);
  }

  /**
   * Type guard: true when the error's response body looks like an Azure error
   * payload, i.e. it has a non-empty `error.code` or `error.type`.
   */
  static errorIsAzureError(error: AxiosError): error is AxiosError<AzureError> {
    const data = error.response?.data as any;
    // Coerce explicitly so the predicate returns a real boolean, not a string.
    return Boolean(data?.error?.code || data?.error?.type);
  }

  /**
   * Splits a credential string of the form
   * `RESOURCE_NAME:DEPLOYMENT_ID:API_KEY` into its parts.
   *
   * @throws Error if any of the three parts is missing or empty.
   */
  static getCredentialsFromKey(key: AzureOpenAIKey) {
    const [resourceName, deploymentId, apiKey] = key.key.split(":");
    if (!resourceName || !deploymentId || !apiKey) {
      throw new Error(
        "Invalid Azure credential format. Refer to .env.example and ensure your credentials are in the format RESOURCE_NAME:DEPLOYMENT_ID:API_KEY with commas between each credential set."
      );
    }
    return { resourceName, deploymentId, apiKey };
  }
}
|
||||
@@ -0,0 +1,212 @@
|
||||
import crypto from "crypto";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import type { AzureOpenAIModelFamily } from "../../models";
|
||||
import { getAzureOpenAIModelFamily } from "../../models";
|
||||
import { OpenAIModel } from "../openai/provider";
|
||||
import { AzureOpenAIKeyChecker } from "./checker";
|
||||
import { AwsKeyChecker } from "../aws/checker";
|
||||
|
||||
export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;
|
||||
|
||||
type AzureOpenAIKeyUsage = {
|
||||
[K in AzureOpenAIModelFamily as `${K}Tokens`]: number;
|
||||
};
|
||||
|
||||
export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
|
||||
readonly service: "azure";
|
||||
readonly modelFamilies: AzureOpenAIModelFamily[];
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/** The time until which this key is rate limited. */
|
||||
rateLimitedUntil: number;
|
||||
contentFiltering: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Upon being rate limited, a key will be locked out for this many milliseconds
|
||||
* while we wait for other concurrent requests to finish.
|
||||
*/
|
||||
const RATE_LIMIT_LOCKOUT = 4000;
|
||||
/**
|
||||
* Upon assigning a key, we will wait this many milliseconds before allowing it
|
||||
* to be used again. This is to prevent the queue from flooding a key with too
|
||||
* many requests while we wait to learn whether previous ones succeeded.
|
||||
*/
|
||||
const KEY_REUSE_DELAY = 250;
|
||||
|
||||
/**
 * Key provider for Azure OpenAI credentials.
 *
 * Loads credentials from `config.azureCredentials` (a comma-separated list),
 * hands out keys by priority, and tracks rate-limit and usage state per key.
 */
export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
  readonly service = "azure";

  private keys: AzureOpenAIKey[] = [];
  private checker?: AzureOpenAIKeyChecker;
  private log = logger.child({ module: "key-provider", service: this.service });

  /**
   * Parses the configured credentials into key records. When no credentials
   * are configured a warning is logged and the provider stays empty.
   */
  constructor() {
    const keyConfig = config.azureCredentials;
    if (!keyConfig) {
      this.log.warn(
        "AZURE_CREDENTIALS is not set. Azure OpenAI API will not be available."
      );
      return;
    }
    // De-duplicate, and drop empty entries left by trailing or doubled commas
    // so they do not become bogus keys.
    const bareKeys = [
      ...new Set(
        keyConfig
          .split(",")
          .map((k) => k.trim())
          .filter((k) => k.length > 0)
      ),
    ];
    for (const key of bareKeys) {
      const newKey: AzureOpenAIKey = {
        key,
        service: this.service,
        modelFamilies: ["azure-gpt4"],
        isDisabled: false,
        isRevoked: false,
        promptCount: 0,
        lastUsed: 0,
        rateLimitedAt: 0,
        rateLimitedUntil: 0,
        contentFiltering: false,
        hash: `azu-${crypto
          .createHash("sha256")
          .update(key)
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
        "azure-turboTokens": 0,
        "azure-gpt4Tokens": 0,
        "azure-gpt4-32kTokens": 0,
        "azure-gpt4-turboTokens": 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded Azure OpenAI keys.");
  }

  /** Starts the key checker when `config.checkKeys` is enabled. */
  public init() {
    if (config.checkKeys) {
      this.checker = new AzureOpenAIKeyChecker(
        this.keys,
        this.update.bind(this)
      );
      this.checker.start();
    }
  }

  /** Returns frozen copies of all keys with the secret `key` field blanked. */
  public list() {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  /**
   * Selects a key able to serve `model`, marks it as used and throttles it.
   *
   * @returns A shallow copy of the selected key.
   * @throws Error if no enabled key supports the model's family.
   */
  public get(model: AzureOpenAIModel) {
    const neededFamily = getAzureOpenAIModelFamily(model);
    const availableKeys = this.keys.filter(
      (k) => !k.isDisabled && k.modelFamilies.includes(neededFamily)
    );
    if (availableKeys.length === 0) {
      throw new Error(`No keys available for model family '${neededFamily}'.`);
    }

    // (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //    a. If all keys were rate limited recently, select the least-recently
    //       rate limited key.
    // 2. Keys which have not been used in the longest time

    const now = Date.now();

    // `availableKeys` is a fresh array from filter(), so sorting it in place
    // does not reorder the pool.
    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }

  /** Disables the pooled key matching `key.hash`, if present and enabled. */
  public disable(key: AzureOpenAIKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  /**
   * Merges `update` into the pooled key with the given hash. `lastChecked` is
   * set to the current time unless `update` supplies its own value.
   */
  public update(hash: string, update: Partial<AzureOpenAIKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }

  /** Number of keys that are not disabled. */
  public available() {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  /**
   * Records one prompt and `tokens` tokens against the key, bucketed by the
   * Azure model family of `model`. Unknown hashes are ignored.
   */
  public incrementUsage(hash: string, model: string, tokens: number) {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
    key.promptCount++;
    key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
  }

  // TODO: all of this shit is duplicate code

  /**
   * Returns how long, in milliseconds, callers should wait before a key is
   * usable: 0 if any enabled key is not rate limited (or there are no enabled
   * keys), otherwise the time until the first key's lockout expires.
   */
  public getLockoutPeriod() {
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return time until the first key is ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }

  /**
   * This is called when we receive a 429 for a request made with this key. We
   * don't have any information on when the key will be usable again, so all
   * we can do is wait a bit and try again. The key is locked for
   * RATE_LIMIT_LOCKOUT milliseconds before it may be retried.
   */
  public markRateLimited(keyHash: string) {
    this.log.debug({ key: keyHash }, "Key rate limited");
    const key = this.keys.find((k) => k.hash === keyHash)!;
    const now = Date.now();
    key.rateLimitedAt = now;
    key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
  }

  /**
   * Marks every key as unchecked and re-enabled, then restarts the checker so
   * the keys are actually re-tested.
   */
  public recheck() {
    this.keys.forEach(({ hash }) =>
      // Clear isRevoked together with isDisabled: the checker sets them as a
      // pair, and the upcoming check will set them again if still warranted.
      this.update(hash, { lastChecked: 0, isDisabled: false, isRevoked: false })
    );
    // Recurring checks are disabled for Azure, so the checker has stopped
    // after the initial batch; without this call nothing would re-test the
    // keys that were just reset.
    this.checker?.scheduleNextCheck();
  }

  /**
   * Applies a short artificial delay to the key upon dequeueing, in order to
   * prevent it from being immediately assigned to another request before the
   * current one can be dispatched.
   **/
  private throttle(hash: string) {
    const now = Date.now();
    const key = this.keys.find((k) => k.hash === hash)!;

    const currentRateLimit = key.rateLimitedUntil;
    const nextRateLimit = now + KEY_REUSE_DELAY;

    key.rateLimitedAt = now;
    // Never shorten a lockout that is already in effect.
    key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
  }
}
|
||||
@@ -2,6 +2,7 @@ import { OpenAIModel } from "./openai/provider";
|
||||
import { AnthropicModel } from "./anthropic/provider";
|
||||
import { GooglePalmModel } from "./palm/provider";
|
||||
import { AwsBedrockModel } from "./aws/provider";
|
||||
import { AzureOpenAIModel } from "./azure/provider";
|
||||
import { KeyPool } from "./key-pool";
|
||||
import type { ModelFamily } from "../models";
|
||||
|
||||
@@ -13,12 +14,18 @@ export type APIFormat =
|
||||
| "openai-text"
|
||||
| "openai-image";
|
||||
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
|
||||
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
|
||||
export type LLMService =
|
||||
| "openai"
|
||||
| "anthropic"
|
||||
| "google-palm"
|
||||
| "aws"
|
||||
| "azure";
|
||||
export type Model =
|
||||
| OpenAIModel
|
||||
| AnthropicModel
|
||||
| GooglePalmModel
|
||||
| AwsBedrockModel;
|
||||
| AwsBedrockModel
|
||||
| AzureOpenAIModel;
|
||||
|
||||
export interface Key {
|
||||
/** The API key itself. Never log this, use `hash` instead. */
|
||||
@@ -72,3 +79,4 @@ export { AnthropicKey } from "./anthropic/provider";
|
||||
export { OpenAIKey } from "./openai/provider";
|
||||
export { GooglePalmKey } from "./palm/provider";
|
||||
export { AwsBedrockKey } from "./aws/provider";
|
||||
export { AzureOpenAIKey } from "./azure/provider";
|
||||
|
||||
@@ -3,14 +3,17 @@ import { logger } from "../../logger";
|
||||
import { Key } from "./index";
|
||||
import { AxiosError } from "axios";
|
||||
|
||||
type KeyCheckerOptions = {
|
||||
type KeyCheckerOptions<TKey extends Key = Key> = {
|
||||
service: string;
|
||||
keyCheckPeriod: number;
|
||||
minCheckInterval: number;
|
||||
}
|
||||
recurringChecksEnabled?: boolean;
|
||||
updateKey: (hash: string, props: Partial<TKey>) => void;
|
||||
};
|
||||
|
||||
export abstract class KeyCheckerBase<TKey extends Key> {
|
||||
protected readonly service: string;
|
||||
protected readonly RECURRING_CHECKS_ENABLED: boolean;
|
||||
/** Minimum time in between any two key checks. */
|
||||
protected readonly MIN_CHECK_INTERVAL: number;
|
||||
/**
|
||||
@@ -19,16 +22,19 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
||||
* than this.
|
||||
*/
|
||||
protected readonly KEY_CHECK_PERIOD: number;
|
||||
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
|
||||
protected readonly keys: TKey[] = [];
|
||||
protected log: pino.Logger;
|
||||
protected timeout?: NodeJS.Timeout;
|
||||
protected lastCheck = 0;
|
||||
|
||||
protected constructor(keys: TKey[], opts: KeyCheckerOptions) {
|
||||
protected constructor(keys: TKey[], opts: KeyCheckerOptions<TKey>) {
|
||||
const { service, keyCheckPeriod, minCheckInterval } = opts;
|
||||
this.keys = keys;
|
||||
this.KEY_CHECK_PERIOD = keyCheckPeriod;
|
||||
this.MIN_CHECK_INTERVAL = minCheckInterval;
|
||||
this.RECURRING_CHECKS_ENABLED = opts.recurringChecksEnabled ?? true;
|
||||
this.updateKey = opts.updateKey;
|
||||
this.service = service;
|
||||
this.log = logger.child({ module: "key-checker", service });
|
||||
}
|
||||
@@ -52,31 +58,34 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
||||
* the minimum check interval.
|
||||
*/
|
||||
public scheduleNextCheck() {
|
||||
// Gives each concurrent check a correlation ID to make logs less confusing.
|
||||
const callId = Math.random().toString(36).slice(2, 8);
|
||||
const timeoutId = this.timeout?.[Symbol.toPrimitive]?.();
|
||||
const checkLog = this.log.child({ callId, timeoutId });
|
||||
|
||||
const enabledKeys = this.keys.filter((key) => !key.isDisabled);
|
||||
checkLog.debug({ enabled: enabledKeys.length }, "Scheduling next check...");
|
||||
const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
|
||||
const numEnabled = enabledKeys.length;
|
||||
const numUnchecked = uncheckedKeys.length;
|
||||
|
||||
clearTimeout(this.timeout);
|
||||
this.timeout = undefined;
|
||||
|
||||
if (enabledKeys.length === 0) {
|
||||
checkLog.warn("All keys are disabled. Key checker stopping.");
|
||||
if (!numEnabled) {
|
||||
checkLog.warn("All keys are disabled. Stopping.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Perform startup checks for any keys that haven't been checked yet.
|
||||
const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
|
||||
checkLog.debug({ unchecked: uncheckedKeys.length }, "# of unchecked keys");
|
||||
if (uncheckedKeys.length > 0) {
|
||||
const keysToCheck = uncheckedKeys.slice(0, 12);
|
||||
checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
|
||||
|
||||
if (numUnchecked > 0) {
|
||||
const keycheckBatch = uncheckedKeys.slice(0, 12);
|
||||
|
||||
this.timeout = setTimeout(async () => {
|
||||
try {
|
||||
await Promise.all(keysToCheck.map((key) => this.checkKey(key)));
|
||||
await Promise.all(keycheckBatch.map((key) => this.checkKey(key)));
|
||||
} catch (error) {
|
||||
this.log.error({ error }, "Error checking one or more keys.");
|
||||
checkLog.error({ error }, "Error checking one or more keys.");
|
||||
}
|
||||
checkLog.info("Batch complete.");
|
||||
this.scheduleNextCheck();
|
||||
@@ -84,11 +93,18 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
||||
|
||||
checkLog.info(
|
||||
{
|
||||
batch: keysToCheck.map((k) => k.hash),
|
||||
remaining: uncheckedKeys.length - keysToCheck.length,
|
||||
batch: keycheckBatch.map((k) => k.hash),
|
||||
remaining: uncheckedKeys.length - keycheckBatch.length,
|
||||
newTimeoutId: this.timeout?.[Symbol.toPrimitive]?.(),
|
||||
},
|
||||
"Scheduled batch check."
|
||||
"Scheduled batch of initial checks."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!this.RECURRING_CHECKS_ENABLED) {
|
||||
checkLog.info(
|
||||
"Initial checks complete and recurring checks are disabled for this service. Stopping."
|
||||
);
|
||||
return;
|
||||
}
|
||||
@@ -106,14 +122,35 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
||||
);
|
||||
|
||||
const delay = nextCheck - Date.now();
|
||||
this.timeout = setTimeout(() => this.checkKey(oldestKey), delay);
|
||||
this.timeout = setTimeout(
|
||||
() => this.checkKey(oldestKey).then(() => this.scheduleNextCheck()),
|
||||
delay
|
||||
);
|
||||
checkLog.debug(
|
||||
{ key: oldestKey.hash, nextCheck: new Date(nextCheck), delay },
|
||||
"Scheduled single key check."
|
||||
"Scheduled next recurring check."
|
||||
);
|
||||
}
|
||||
|
||||
protected abstract checkKey(key: TKey): Promise<void>;
|
||||
public async checkKey(key: TKey): Promise<void> {
|
||||
if (key.isDisabled) {
|
||||
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
|
||||
this.scheduleNextCheck();
|
||||
return;
|
||||
}
|
||||
this.log.debug({ key: key.hash }, "Checking key...");
|
||||
|
||||
try {
|
||||
await this.testKeyOrFail(key);
|
||||
} catch (error) {
|
||||
this.updateKey(key.hash, {});
|
||||
this.handleAxiosError(key, error as AxiosError);
|
||||
}
|
||||
|
||||
this.lastCheck = Date.now();
|
||||
}
|
||||
|
||||
protected abstract testKeyOrFail(key: TKey): Promise<void>;
|
||||
|
||||
protected abstract handleAxiosError(key: TKey, error: AxiosError): void;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import { GooglePalmKeyProvider } from "./palm/provider";
|
||||
import { AwsBedrockKeyProvider } from "./aws/provider";
|
||||
import { ModelFamily } from "../models";
|
||||
import { assertNever } from "../utils";
|
||||
import { AzureOpenAIKeyProvider } from "./azure/provider";
|
||||
|
||||
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
|
||||
|
||||
@@ -25,6 +26,7 @@ export class KeyPool {
|
||||
this.keyProviders.push(new AnthropicKeyProvider());
|
||||
this.keyProviders.push(new GooglePalmKeyProvider());
|
||||
this.keyProviders.push(new AwsBedrockKeyProvider());
|
||||
this.keyProviders.push(new AzureOpenAIKeyProvider());
|
||||
}
|
||||
|
||||
public init() {
|
||||
@@ -124,6 +126,8 @@ export class KeyPool {
|
||||
// AWS offers models from a few providers
|
||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||
return "aws";
|
||||
} else if (model.startsWith("azure")) {
|
||||
return "azure";
|
||||
}
|
||||
throw new Error(`Unknown service for model '${model}'`);
|
||||
}
|
||||
@@ -142,6 +146,11 @@ export class KeyPool {
|
||||
return "google-palm";
|
||||
case "aws-claude":
|
||||
return "aws";
|
||||
case "azure-turbo":
|
||||
case "azure-gpt4":
|
||||
case "azure-gpt4-32k":
|
||||
case "azure-gpt4-turbo":
|
||||
return "azure";
|
||||
default:
|
||||
assertNever(modelFamily);
|
||||
}
|
||||
|
||||
@@ -27,65 +27,41 @@ type UpdateFn = typeof OpenAIKeyProvider.prototype.update;
|
||||
|
||||
export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
|
||||
private readonly cloneKey: CloneFn;
|
||||
private readonly updateKey: UpdateFn;
|
||||
|
||||
constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) {
|
||||
super(keys, {
|
||||
service: "openai",
|
||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||
recurringChecksEnabled: false,
|
||||
updateKey,
|
||||
});
|
||||
this.cloneKey = cloneFn;
|
||||
this.updateKey = updateKey;
|
||||
}
|
||||
|
||||
protected async checkKey(key: OpenAIKey) {
|
||||
if (key.isDisabled) {
|
||||
this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
|
||||
this.scheduleNextCheck();
|
||||
return;
|
||||
}
|
||||
|
||||
this.log.debug({ key: key.hash }, "Checking key...");
|
||||
let isInitialCheck = !key.lastChecked;
|
||||
try {
|
||||
// We only need to check for provisioned models on the initial check.
|
||||
if (isInitialCheck) {
|
||||
const [provisionedModels, livenessTest] = await Promise.all([
|
||||
this.getProvisionedModels(key),
|
||||
this.testLiveness(key),
|
||||
this.maybeCreateOrganizationClones(key),
|
||||
]);
|
||||
const updates = {
|
||||
modelFamilies: provisionedModels,
|
||||
isTrial: livenessTest.rateLimit <= 250,
|
||||
};
|
||||
this.updateKey(key.hash, updates);
|
||||
} else {
|
||||
// No updates needed as models and trial status generally don't change.
|
||||
const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
|
||||
this.updateKey(key.hash, {});
|
||||
}
|
||||
this.log.info(
|
||||
{ key: key.hash, models: key.modelFamilies, trial: key.isTrial },
|
||||
"Key check complete."
|
||||
);
|
||||
} catch (error) {
|
||||
// touch the key so we don't check it again for a while
|
||||
protected async testKeyOrFail(key: OpenAIKey) {
|
||||
// We only need to check for provisioned models on the initial check.
|
||||
const isInitialCheck = !key.lastChecked;
|
||||
if (isInitialCheck) {
|
||||
const [provisionedModels, livenessTest] = await Promise.all([
|
||||
this.getProvisionedModels(key),
|
||||
this.testLiveness(key),
|
||||
this.maybeCreateOrganizationClones(key),
|
||||
]);
|
||||
const updates = {
|
||||
modelFamilies: provisionedModels,
|
||||
isTrial: livenessTest.rateLimit <= 250,
|
||||
};
|
||||
this.updateKey(key.hash, updates);
|
||||
} else {
|
||||
// No updates needed as models and trial status generally don't change.
|
||||
const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
|
||||
this.updateKey(key.hash, {});
|
||||
this.handleAxiosError(key, error as AxiosError);
|
||||
}
|
||||
|
||||
this.lastCheck = Date.now();
|
||||
// Only enqueue the next check if this wasn't a startup check, since those
|
||||
// are batched together elsewhere.
|
||||
if (!isInitialCheck) {
|
||||
this.log.info(
|
||||
{ key: key.hash },
|
||||
"Recurring keychecks are disabled, no-op."
|
||||
);
|
||||
// this.scheduleNextCheck();
|
||||
}
|
||||
this.log.info(
|
||||
{ key: key.hash, models: key.modelFamilies, trial: key.isTrial },
|
||||
"Checked key."
|
||||
);
|
||||
}
|
||||
|
||||
private async getProvisionedModels(
|
||||
@@ -138,6 +114,17 @@ export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
|
||||
.filter(({ is_default }) => !is_default)
|
||||
.map(({ id }) => id);
|
||||
this.cloneKey(key.hash, ids);
|
||||
|
||||
// It's possible that the keychecker may be stopped if all non-cloned keys
|
||||
// happened to be unusable, in which case this clone will never be checked
|
||||
// unless we restart the keychecker.
|
||||
if (!this.timeout) {
|
||||
this.log.warn(
|
||||
{ parent: key.hash },
|
||||
"Restarting key checker to check cloned keys."
|
||||
);
|
||||
this.scheduleNextCheck();
|
||||
}
|
||||
}
|
||||
|
||||
protected handleAxiosError(key: OpenAIKey, error: AxiosError) {
|
||||
|
||||
@@ -217,17 +217,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
||||
return a.lastUsed - b.lastUsed;
|
||||
});
|
||||
|
||||
// logger.debug(
|
||||
// {
|
||||
// byPriority: keysByPriority.map((k) => ({
|
||||
// hash: k.hash,
|
||||
// isRateLimited: now - k.rateLimitedAt < rateLimitThreshold,
|
||||
// modelFamilies: k.modelFamilies,
|
||||
// })),
|
||||
// },
|
||||
// "Keys sorted by priority"
|
||||
// );
|
||||
|
||||
const selectedKey = keysByPriority[0];
|
||||
selectedKey.lastUsed = now;
|
||||
this.throttle(selectedKey.hash);
|
||||
|
||||
+34
-2
@@ -2,15 +2,25 @@
|
||||
|
||||
import pino from "pino";
|
||||
|
||||
export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "gpt4-turbo" | "dall-e";
|
||||
export type OpenAIModelFamily =
|
||||
| "turbo"
|
||||
| "gpt4"
|
||||
| "gpt4-32k"
|
||||
| "gpt4-turbo"
|
||||
| "dall-e";
|
||||
export type AnthropicModelFamily = "claude";
|
||||
export type GooglePalmModelFamily = "bison";
|
||||
export type AwsBedrockModelFamily = "aws-claude";
|
||||
export type AzureOpenAIModelFamily = `azure-${Exclude<
|
||||
OpenAIModelFamily,
|
||||
"dall-e"
|
||||
>}`;
|
||||
export type ModelFamily =
|
||||
| OpenAIModelFamily
|
||||
| AnthropicModelFamily
|
||||
| GooglePalmModelFamily
|
||||
| AwsBedrockModelFamily;
|
||||
| AwsBedrockModelFamily
|
||||
| AzureOpenAIModelFamily;
|
||||
|
||||
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
||||
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
|
||||
@@ -23,6 +33,10 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
||||
"claude",
|
||||
"bison",
|
||||
"aws-claude",
|
||||
"azure-turbo",
|
||||
"azure-gpt4",
|
||||
"azure-gpt4-32k",
|
||||
"azure-gpt4-turbo",
|
||||
] as const);
|
||||
|
||||
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
||||
@@ -64,6 +78,24 @@ export function getAwsBedrockModelFamily(_model: string): ModelFamily {
|
||||
return "aws-claude";
|
||||
}
|
||||
|
||||
/**
 * Maps an Azure OpenAI model name to its Azure model family.
 *
 * The name is first normalised to its OpenAI spelling: Azure omits the period
 * in "gpt-35-turbo", and addAzureKey prepends "azure-" to route the request
 * to the correct key provider. The normalised name is then matched against
 * the OpenAI family patterns in declaration order.
 *
 * @param model Model name as reported by Azure or as routed by the proxy.
 * @param defaultFamily Family returned when no pattern matches.
 * @returns The first matching family prefixed with "azure-", or
 *   `defaultFamily`.
 */
export function getAzureOpenAIModelFamily(
  model: string,
  defaultFamily: AzureOpenAIModelFamily = "azure-gpt4"
): AzureOpenAIModelFamily {
  const openAIName = model
    .replace("gpt-35-turbo", "gpt-3.5-turbo")
    .replace("azure-", "");
  const hit = Object.entries(OPENAI_MODEL_FAMILY_MAP).find(([pattern]) =>
    openAIName.match(pattern)
  );
  if (!hit) return defaultFamily;
  return `azure-${hit[1]}` as AzureOpenAIModelFamily;
}
|
||||
|
||||
export function assertIsKnownModelFamily(
|
||||
modelFamily: string
|
||||
): asserts modelFamily is ModelFamily {
|
||||
|
||||
@@ -12,6 +12,7 @@ import schedule from "node-schedule";
|
||||
import { v4 as uuid } from "uuid";
|
||||
import { config, getFirebaseApp } from "../../config";
|
||||
import {
|
||||
getAzureOpenAIModelFamily,
|
||||
getClaudeModelFamily,
|
||||
getGooglePalmModelFamily,
|
||||
getOpenAIModelFamily,
|
||||
@@ -34,6 +35,10 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
|
||||
claude: 0,
|
||||
bison: 0,
|
||||
"aws-claude": 0,
|
||||
"azure-turbo": 0,
|
||||
"azure-gpt4": 0,
|
||||
"azure-gpt4-turbo": 0,
|
||||
"azure-gpt4-32k": 0,
|
||||
};
|
||||
|
||||
const users: Map<string, User> = new Map();
|
||||
@@ -382,6 +387,9 @@ function getModelFamilyForQuotaUsage(
|
||||
model: string,
|
||||
api: APIFormat
|
||||
): ModelFamily {
|
||||
// TODO: this seems incorrect
|
||||
if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
|
||||
|
||||
switch (api) {
|
||||
case "openai":
|
||||
case "openai-text":
|
||||
|
||||
Reference in New Issue
Block a user