diff --git a/src/shared/key-management/aws/checker.ts b/src/shared/key-management/aws/checker.ts index 578e24a..d1c77dd 100644 --- a/src/shared/key-management/aws/checker.ts +++ b/src/shared/key-management/aws/checker.ts @@ -5,9 +5,21 @@ import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios"; import { URL } from "url"; import { KeyCheckerBase } from "../key-checker-base"; import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider"; -import { AwsBedrockModelFamily } from "../../models"; +import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models"; import { config } from "../../../config"; +const KNOWN_MODEL_IDS = [ + "anthropic.claude-v2", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0", + "anthropic.claude-3-opus-20240229-v1:0", + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "mistral.mistral-7b-instruct-v0:2", + "mistral.mixtral-8x7b-instruct-v0:1", + "mistral.mistral-large-2402-v1:0", + "mistral.mistral-large-2407-v1:0", + "mistral.mistral-small-2402-v1:0", +]; const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes const AMZ_HOST = @@ -47,41 +59,20 @@ export class AwsKeyChecker extends KeyCheckerBase { } protected async testKeyOrFail(key: AwsBedrockKey) { - // Only check models on startup. For now all models must be available to - // the proxy because we don't route requests to different keys. - let checks: Promise[] = []; const isInitialCheck = !key.lastChecked; - if (isInitialCheck) { - checks = [ - this.invokeModel("anthropic.claude-v2", key), - this.invokeModel("anthropic.claude-3-sonnet-20240229-v1:0", key), - this.invokeModel("anthropic.claude-3-haiku-20240307-v1:0", key), - this.invokeModel("anthropic.claude-3-opus-20240229-v1:0", key), - this.invokeModel("anthropic.claude-3-5-sonnet-20240620-v1:0", key), - this.invokeModel("mistral.mistral-7b-instruct-v0:2", key), - this.invokeModel("mistral.mixtral-8x7b-instruct-v0:1", key), - this.invokeModel("mistral.mistral-large-2402-v1:0", key), - this.invokeModel("mistral.mistral-large-2407-v1:0", key), - this.invokeModel("mistral.mistral-small-2402-v1:0", key), - ]; - } - - checks.unshift(this.checkLoggingConfiguration(key)); - - const [_logging, claudeV2, sonnet, haiku, opus, sonnet35] = - await Promise.all(checks); - - this.log.debug( - { key: key.hash, _logging, claudeV2, sonnet, haiku, opus, sonnet35 }, - "AWS model tests complete." - ); if (isInitialCheck) { - const families: AwsBedrockModelFamily[] = []; - if (claudeV2 || sonnet || sonnet35 || haiku) families.push("aws-claude"); - if (opus) families.push("aws-claude-opus"); + const checks = await Promise.all( + KNOWN_MODEL_IDS.map(async (model) => { + const success = await this.invokeModel(model, key); + return { model, success }; + }) + ); + const modelIds = checks + .filter(({ success }) => success) + .map(({ model }) => model); - if (families.length === 0) { + if (modelIds.length === 0) { this.log.warn( { key: key.hash }, "Key does not have access to any models; disabling." @@ -90,20 +81,19 @@ export class AwsKeyChecker extends KeyCheckerBase { } this.updateKey(key.hash, { - sonnetEnabled: sonnet, - haikuEnabled: haiku, - sonnet35Enabled: sonnet35, - modelFamilies: families, + modelIds, + modelFamilies: Array.from( + new Set(modelIds.map(getAwsBedrockModelFamily)) + ), }); } - this.log.info( + this.log.debug( { key: key.hash, - sonnet, - haiku, - families: key.modelFamilies, logged: key.awsLoggingStatus, + families: key.modelFamilies, + models: key.modelIds, }, "Checked key." ); @@ -174,7 +164,10 @@ export class AwsKeyChecker extends KeyCheckerBase { * key has access to the model, false if it does not. Throws an error if the * key is disabled. */ - private async invokeModel(model: string, key: AwsBedrockKey) { + private async invokeModel( + model: string, + key: AwsBedrockKey + ): Promise { const creds = AwsKeyChecker.getCredentialsFromKey(key); // This is not a valid invocation payload, but a 400 response indicates that // the principal at least has permission to invoke the model. @@ -208,7 +201,7 @@ export class AwsKeyChecker extends KeyCheckerBase { ) { return false; } - + // ResourceNotFound typically indicates that the tested model cannot be used // on the configured region for this set of credentials. if (status === 404) { diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts index fe23809..68378ff 100644 --- a/src/shared/key-management/aws/provider.ts +++ b/src/shared/key-management/aws/provider.ts @@ -1,5 +1,5 @@ import crypto from "crypto"; -import { Key, KeyProvider } from ".."; +import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; import { config } from "../../../config"; import { logger } from "../../../logger"; import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models"; @@ -13,10 +13,6 @@ type AwsBedrockKeyUsage = { export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage { readonly service: "aws"; readonly modelFamilies: AwsBedrockModelFamily[]; - /** The time at which this key was last rate limited. */ - rateLimitedAt: number; - /** The time until which this key is rate limited. */ - rateLimitedUntil: number; /** * The confirmed logging status of this key. This is "unknown" until we * receive a response from the AWS API. Keys which are logged, or not @@ -24,9 +20,11 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage { * set. */ awsLoggingStatus: "unknown" | "disabled" | "enabled"; - sonnetEnabled: boolean; - haikuEnabled: boolean; - sonnet35Enabled: boolean; + // TODO: replace with list of model ids + // sonnetEnabled: boolean; + // haikuEnabled: boolean; + // sonnet35Enabled: boolean; + modelIds: string[]; } /** @@ -76,11 +74,16 @@ export class AwsBedrockKeyProvider implements KeyProvider { .digest("hex") .slice(0, 8)}`, lastChecked: 0, - sonnetEnabled: true, - haikuEnabled: false, - sonnet35Enabled: false, + modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"], + // sonnetEnabled: true, + // haikuEnabled: false, + // sonnet35Enabled: false, ["aws-claudeTokens"]: 0, ["aws-claude-opusTokens"]: 0, + ["aws-mistral-tinyTokens"]: 0, + ["aws-mistral-smallTokens"]: 0, + ["aws-mistral-mediumTokens"]: 0, + ["aws-mistral-largeTokens"]: 0, }; this.keys.push(newKey); } @@ -99,41 +102,35 @@ export class AwsBedrockKeyProvider implements KeyProvider { } public get(model: string) { + let neededVariantId = model; + // The only AWS model that breaks naming convention is Claude v2. Anthropic + // calls this claude-2 but AWS calls it claude-v2. + if (model.includes("claude-2")) neededVariantId = "claude-v2"; const neededFamily = getAwsBedrockModelFamily(model); - // this is a horrible mess - // each of these should be separate model families, but adding model - // families is not low enough friction for the rate at which aws claude - // model variants are added. - const needsSonnet35 = - model.includes("claude-3-5-sonnet") && neededFamily === "aws-claude"; - const needsSonnet = - !needsSonnet35 && - model.includes("sonnet") && - neededFamily === "aws-claude"; - const needsHaiku = model.includes("haiku") && neededFamily === "aws-claude"; - const availableKeys = this.keys.filter((k) => { - const isNotLogged = k.awsLoggingStatus !== "enabled"; + // Select keys which return ( + // are enabled !k.isDisabled && - (isNotLogged || config.allowAwsLogging) && - (k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under aws-claude, while opus is not - (k.haikuEnabled || !needsHaiku) && - (k.sonnet35Enabled || !needsSonnet35) && - k.modelFamilies.includes(neededFamily) + // are not logged, unless policy allows it + (config.allowAwsLogging || k.awsLoggingStatus !== "enabled") && + // have access to the model family we need + k.modelFamilies.includes(neededFamily) && + // have access to the specific variant we need + // note that requests can be made for the AWS ID or original vendor ID; + // all vendor IDs are substrings of the AWS ID. + k.modelIds.some((m) => m.includes(neededVariantId)) ); }); this.log.debug( { - model, - neededFamily, - needsSonnet, - needsHaiku, - needsSonnet35, - availableKeys: availableKeys.length, + requestedModel: model, + selectedVariant: neededVariantId, + selectedFamily: neededFamily, totalKeys: this.keys.length, + availableKeys: availableKeys.length, }, "Selecting AWS key" ); @@ -195,22 +192,7 @@ export class AwsBedrockKeyProvider implements KeyProvider { key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens; } - public getLockoutPeriod() { - // TODO: same exact behavior for three providers, should be refactored - const activeKeys = this.keys.filter((k) => !k.isDisabled); - // Don't lock out if there are no keys available or the queue will stall. - // Just let it through so the add-key middleware can throw an error. - if (activeKeys.length === 0) return 0; - - const now = Date.now(); - const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil); - const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length; - - if (anyNotRateLimited) return 0; - - // If all keys are rate-limited, return time until the first key is ready. - return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now)); - } + getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); /** * This is called when we receive a 429, which means there are already five diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index c1b3f6a..56ef371 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -85,8 +85,9 @@ export function createGenericGetLockoutPeriod( export const keyPool = new KeyPool(); export { AnthropicKey } from "./anthropic/provider"; -export { OpenAIKey } from "./openai/provider"; -export { GoogleAIKey } from "././google-ai/provider"; export { AwsBedrockKey } from "./aws/provider"; export { GcpKey } from "./gcp/provider"; export { AzureOpenAIKey } from "./azure/provider"; +export { GoogleAIKey } from "././google-ai/provider"; +export { MistralAIKey } from "./mistral-ai/provider"; +export { OpenAIKey } from "./openai/provider"; diff --git a/src/shared/models.ts b/src/shared/models.ts index 991ec94..12ff589 100644 --- a/src/shared/models.ts +++ b/src/shared/models.ts @@ -192,7 +192,7 @@ export function getMistralAIModelFamily(model: string): MistralAIModelFamily { export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily { // remove vendor and version from AWS model ids // 'anthropic.claude-3-5-sonnet-20240620-v1:0' -> 'claude-3-5-sonnet-20240620' - const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d)?(:\d+)*$/, "$2"); + const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d+)?(:\d+)*$/, "$2"); if (["claude", "anthropic"].some((x) => model.includes(x))) { return `aws-${getClaudeModelFamily(deAwsified)}`;