From dc1b5730204d45959643e2e3c7202a7965d171c7 Mon Sep 17 00:00:00 2001 From: nai-degen Date: Sun, 11 Aug 2024 12:21:28 -0500 Subject: [PATCH] small KeyProvider#get refactor --- src/shared/key-management/aws/provider.ts | 33 +++--------- src/shared/key-management/azure/provider.ts | 37 ++++--------- src/shared/key-management/gcp/provider.ts | 54 +++---------------- .../key-management/google-ai/provider.ts | 28 ++-------- .../key-management/mistral-ai/provider.ts | 33 +++--------- src/shared/key-management/prioritize-keys.ts | 24 +++++++++ 6 files changed, 57 insertions(+), 152 deletions(-) create mode 100644 src/shared/key-management/prioritize-keys.ts diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts index 53dcfc9..97d320d 100644 --- a/src/shared/key-management/aws/provider.ts +++ b/src/shared/key-management/aws/provider.ts @@ -1,10 +1,11 @@ import crypto from "crypto"; -import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; import { config } from "../../../config"; import { logger } from "../../../logger"; -import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models"; -import { AwsKeyChecker } from "./checker"; import { PaymentRequiredError } from "../../errors"; +import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models"; +import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; +import { prioritizeKeys } from "../prioritize-keys"; +import { AwsKeyChecker } from "./checker"; type AwsBedrockKeyUsage = { [K in AwsBedrockModelFamily as `${K}Tokens`]: number; @@ -137,30 +138,8 @@ export class AwsBedrockKeyProvider implements KeyProvider { ); } - // (largely copied from the OpenAI provider, without trial key support) - // Select a key, from highest priority to lowest priority: - // 1. Keys which are not rate limited - // a. If all keys were rate limited recently, select the least-recently - // rate limited key. - // 3. 
Keys which have not been used in the longest time - - const now = Date.now(); - - const keysByPriority = availableKeys.sort((a, b) => { - const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT; - const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT; - - if (aRateLimited && !bRateLimited) return 1; - if (!aRateLimited && bRateLimited) return -1; - if (aRateLimited && bRateLimited) { - return a.rateLimitedAt - b.rateLimitedAt; - } - - return a.lastUsed - b.lastUsed; - }); - - const selectedKey = keysByPriority[0]; - selectedKey.lastUsed = now; + const selectedKey = prioritizeKeys(availableKeys)[0]; + selectedKey.lastUsed = Date.now(); this.throttle(selectedKey.hash); return { ...selectedKey }; } diff --git a/src/shared/key-management/azure/provider.ts b/src/shared/key-management/azure/provider.ts index 8a7e48a..28439fa 100644 --- a/src/shared/key-management/azure/provider.ts +++ b/src/shared/key-management/azure/provider.ts @@ -1,10 +1,13 @@ import crypto from "crypto"; -import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; import { config } from "../../../config"; -import { PaymentRequiredError } from "../../errors"; import { logger } from "../../../logger"; -import type { AzureOpenAIModelFamily } from "../../models"; -import { getAzureOpenAIModelFamily } from "../../models"; +import { PaymentRequiredError } from "../../errors"; +import { + AzureOpenAIModelFamily, + getAzureOpenAIModelFamily, +} from "../../models"; +import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; +import { prioritizeKeys } from "../prioritize-keys"; import { AzureOpenAIKeyChecker } from "./checker"; type AzureOpenAIKeyUsage = { @@ -101,30 +104,8 @@ export class AzureOpenAIKeyProvider implements KeyProvider { ); } - // (largely copied from the OpenAI provider, without trial key support) - // Select a key, from highest priority to lowest priority: - // 1. Keys which are not rate limited - // a. 
If all keys were rate limited recently, select the least-recently - // rate limited key. - // 3. Keys which have not been used in the longest time - - const now = Date.now(); - - const keysByPriority = availableKeys.sort((a, b) => { - const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT; - const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT; - - if (aRateLimited && !bRateLimited) return 1; - if (!aRateLimited && bRateLimited) return -1; - if (aRateLimited && bRateLimited) { - return a.rateLimitedAt - b.rateLimitedAt; - } - - return a.lastUsed - b.lastUsed; - }); - - const selectedKey = keysByPriority[0]; - selectedKey.lastUsed = now; + const selectedKey = prioritizeKeys(availableKeys)[0]; + selectedKey.lastUsed = Date.now(); this.throttle(selectedKey.hash); return { ...selectedKey }; } diff --git a/src/shared/key-management/gcp/provider.ts b/src/shared/key-management/gcp/provider.ts index 8e9c9ab..e3f72ef 100644 --- a/src/shared/key-management/gcp/provider.ts +++ b/src/shared/key-management/gcp/provider.ts @@ -1,10 +1,11 @@ import crypto from "crypto"; -import { Key, KeyProvider } from ".."; import { config } from "../../../config"; import { logger } from "../../../logger"; -import { GcpModelFamily, getGcpModelFamily } from "../../models"; -import { GcpKeyChecker } from "./checker"; import { PaymentRequiredError } from "../../errors"; +import { GcpModelFamily, getGcpModelFamily } from "../../models"; +import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; +import { prioritizeKeys } from "../prioritize-keys"; +import { GcpKeyChecker } from "./checker"; type GcpKeyUsage = { [K in GcpModelFamily as `${K}Tokens`]: number; @@ -13,10 +14,6 @@ type GcpKeyUsage = { export interface GcpKey extends Key, GcpKeyUsage { readonly service: "gcp"; readonly modelFamilies: GcpModelFamily[]; - /** The time at which this key was last rate limited. */ - rateLimitedAt: number; - /** The time until which this key is rate limited. 
*/ - rateLimitedUntil: number; sonnetEnabled: boolean; haikuEnabled: boolean; sonnet35Enabled: boolean; @@ -134,30 +131,8 @@ export class GcpKeyProvider implements KeyProvider { ); } - // (largely copied from the OpenAI provider, without trial key support) - // Select a key, from highest priority to lowest priority: - // 1. Keys which are not rate limited - // a. If all keys were rate limited recently, select the least-recently - // rate limited key. - // 3. Keys which have not been used in the longest time - - const now = Date.now(); - - const keysByPriority = availableKeys.sort((a, b) => { - const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT; - const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT; - - if (aRateLimited && !bRateLimited) return 1; - if (!aRateLimited && bRateLimited) return -1; - if (aRateLimited && bRateLimited) { - return a.rateLimitedAt - b.rateLimitedAt; - } - - return a.lastUsed - b.lastUsed; - }); - - const selectedKey = keysByPriority[0]; - selectedKey.lastUsed = now; + const selectedKey = prioritizeKeys(availableKeys)[0]; + selectedKey.lastUsed = Date.now(); this.throttle(selectedKey.hash); return { ...selectedKey }; } @@ -185,22 +160,7 @@ export class GcpKeyProvider implements KeyProvider { key[`${getGcpModelFamily(model)}Tokens`] += tokens; } - public getLockoutPeriod() { - // TODO: same exact behavior for three providers, should be refactored - const activeKeys = this.keys.filter((k) => !k.isDisabled); - // Don't lock out if there are no keys available or the queue will stall. - // Just let it through so the add-key middleware can throw an error. - if (activeKeys.length === 0) return 0; - - const now = Date.now(); - const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil); - const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length; - - if (anyNotRateLimited) return 0; - - // If all keys are rate-limited, return time until the first key is ready. 
- return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now)); - } + getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); /** * This is called when we receive a 429, which means there are already five diff --git a/src/shared/key-management/google-ai/provider.ts b/src/shared/key-management/google-ai/provider.ts index e94460b..a7abfc1 100644 --- a/src/shared/key-management/google-ai/provider.ts +++ b/src/shared/key-management/google-ai/provider.ts @@ -1,9 +1,10 @@ import crypto from "crypto"; -import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; import { config } from "../../../config"; import { logger } from "../../../logger"; -import { getGoogleAIModelFamily, type GoogleAIModelFamily } from "../../models"; import { PaymentRequiredError } from "../../errors"; +import { getGoogleAIModelFamily, type GoogleAIModelFamily } from "../../models"; +import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; +import { prioritizeKeys } from "../prioritize-keys"; import { GoogleAIKeyChecker } from "./checker"; // Note that Google AI is not the same as Vertex AI, both are provided by @@ -108,29 +109,10 @@ export class GoogleAIKeyProvider implements KeyProvider { throw new PaymentRequiredError("No Google AI keys available"); } - // Select a key, from highest priority to lowest priority: - // 1. Keys which are not rate limited - // a. If all keys were rate limited recently, select the least-recently - // rate limited key. - // 3. 
Keys which have not been used in the longest time - - const now = Date.now(); - - const keysByPriority = availableKeys.sort((a, b) => { - const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT; - const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT; - - if (aRateLimited && !bRateLimited) return 1; - if (!aRateLimited && bRateLimited) return -1; - if (aRateLimited && bRateLimited) { - return a.rateLimitedAt - b.rateLimitedAt; - } - - return a.lastUsed - b.lastUsed; - }); + const keysByPriority = prioritizeKeys(availableKeys); const selectedKey = keysByPriority[0]; - selectedKey.lastUsed = now; + selectedKey.lastUsed = Date.now(); this.throttle(selectedKey.hash); return { ...selectedKey }; } diff --git a/src/shared/key-management/mistral-ai/provider.ts b/src/shared/key-management/mistral-ai/provider.ts index 20f9dae..a8460e6 100644 --- a/src/shared/key-management/mistral-ai/provider.ts +++ b/src/shared/key-management/mistral-ai/provider.ts @@ -1,10 +1,11 @@ import crypto from "crypto"; -import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; import { config } from "../../../config"; import { logger } from "../../../logger"; -import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models"; -import { MistralAIKeyChecker } from "./checker"; import { HttpError } from "../../errors"; +import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models"; +import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; +import { prioritizeKeys } from "../prioritize-keys"; +import { MistralAIKeyChecker } from "./checker"; type MistralAIKeyUsage = { [K in MistralAIModelFamily as `${K}Tokens`]: number; @@ -94,30 +95,8 @@ export class MistralAIKeyProvider implements KeyProvider { throw new HttpError(402, "No Mistral AI keys available"); } - // (largely copied from the OpenAI provider, without trial key support) - // Select a key, from highest priority to lowest priority: - // 1. Keys which are not rate limited - // a. 
If all keys were rate limited recently, select the least-recently - // rate limited key. - // 3. Keys which have not been used in the longest time - - const now = Date.now(); - - const keysByPriority = availableKeys.sort((a, b) => { - const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT; - const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT; - - if (aRateLimited && !bRateLimited) return 1; - if (!aRateLimited && bRateLimited) return -1; - if (aRateLimited && bRateLimited) { - return a.rateLimitedAt - b.rateLimitedAt; - } - - return a.lastUsed - b.lastUsed; - }); - - const selectedKey = keysByPriority[0]; - selectedKey.lastUsed = now; + const selectedKey = prioritizeKeys(availableKeys)[0]; + selectedKey.lastUsed = Date.now(); this.throttle(selectedKey.hash); return { ...selectedKey }; }
diff --git a/src/shared/key-management/prioritize-keys.ts b/src/shared/key-management/prioritize-keys.ts
new file mode 100644
index 0000000..cf52995
--- /dev/null
+++ b/src/shared/key-management/prioritize-keys.ts
@@ -0,0 +1,24 @@
+import { Key } from "./index";
+
+export function prioritizeKeys<T extends Key>(keys: T[]) {
+  // Sorts keys in place, highest priority first, where priority is:
+  // 1. Keys which are not rate limited (now has reached rateLimitedUntil)
+  //    a. If all keys are still rate limited, prefer the least-recently
+  //       rate limited key.
+  // 2. Keys which have not been used in the longest time
+
+  const now = Date.now();
+
+  return keys.sort((a, b) => {
+    const aRateLimited = now < a.rateLimitedUntil;
+    const bRateLimited = now < b.rateLimitedUntil;
+
+    if (aRateLimited && !bRateLimited) return 1;
+    if (!aRateLimited && bRateLimited) return -1;
+    if (aRateLimited && bRateLimited) {
+      return a.rateLimitedAt - b.rateLimitedAt;
+    }
+
+    return a.lastUsed - b.lastUsed;
+  });
+}