From c6467b02f3057dfcc2c03ead4b851eb824ec27bd Mon Sep 17 00:00:00 2001 From: nai-degen Date: Sat, 10 Aug 2024 14:41:25 -0500 Subject: [PATCH] adds AWS mistral model families and checker IDs --- src/shared/key-management/aws/checker.ts | 5 ++++ src/shared/models.ts | 29 ++++++++++++++++++++---- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/shared/key-management/aws/checker.ts b/src/shared/key-management/aws/checker.ts index de22c64..578e24a 100644 --- a/src/shared/key-management/aws/checker.ts +++ b/src/shared/key-management/aws/checker.ts @@ -58,6 +58,11 @@ export class AwsKeyChecker extends KeyCheckerBase { this.invokeModel("anthropic.claude-3-haiku-20240307-v1:0", key), this.invokeModel("anthropic.claude-3-opus-20240229-v1:0", key), this.invokeModel("anthropic.claude-3-5-sonnet-20240620-v1:0", key), + this.invokeModel("mistral.mistral-7b-instruct-v0:2", key), + this.invokeModel("mistral.mixtral-8x7b-instruct-v0:1", key), + this.invokeModel("mistral.mistral-large-2402-v1:0", key), + this.invokeModel("mistral.mistral-large-2407-v1:0", key), + this.invokeModel("mistral.mistral-small-2402-v1:0", key), ]; } diff --git a/src/shared/models.ts b/src/shared/models.ts index 2004374..991ec94 100644 --- a/src/shared/models.ts +++ b/src/shared/models.ts @@ -32,7 +32,9 @@ export type MistralAIModelFamily = // mistral changes their model classes frequently so these no longer // correspond to specific models. consider them rough pricing tiers. "mistral-tiny" | "mistral-small" | "mistral-medium" | "mistral-large"; -export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus"; +export type AwsBedrockModelFamily = `aws-${ + | AnthropicModelFamily + | MistralAIModelFamily}`; export type GcpModelFamily = "gcp-claude" | "gcp-claude-opus"; export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`; export type ModelFamily = @@ -64,6 +66,10 @@ export const MODEL_FAMILIES = (( "mistral-large", "aws-claude", "aws-claude-opus", + "aws-mistral-tiny", + "aws-mistral-small", + "aws-mistral-medium", + "aws-mistral-large", "gcp-claude", "gcp-claude-opus", "azure-turbo", @@ -99,6 +105,10 @@ export const MODEL_FAMILY_SERVICE: { "claude-opus": "anthropic", "aws-claude": "aws", "aws-claude-opus": "aws", + "aws-mistral-tiny": "aws", + "aws-mistral-small": "aws", + "aws-mistral-medium": "aws", + "aws-mistral-large": "aws", "gcp-claude": "gcp", "gcp-claude-opus": "gcp", "azure-turbo": "azure", @@ -180,8 +190,16 @@ export function getMistralAIModelFamily(model: string): MistralAIModelFamily { } export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily { - if (model.includes("opus")) return "aws-claude-opus"; - return "aws-claude"; + // remove vendor and version from AWS model ids + // 'anthropic.claude-3-5-sonnet-20240620-v1:0' -> 'claude-3-5-sonnet-20240620' + const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d)?(:\d+)*$/, "$2"); + + if (["claude", "anthropic"].some((x) => model.includes(x))) { + return `aws-${getClaudeModelFamily(deAwsified)}`; + } else if (model.includes("tral")) { + return `aws-${getMistralAIModelFamily(deAwsified)}`; + } + return `aws-claude`; } export function getGcpModelFamily(model: string): GcpModelFamily { @@ -223,8 +241,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily { const model = req.body.model ?? "gpt-3.5-turbo"; let modelFamily: ModelFamily; - // Weird special case for AWS/GCP/Azure because they serve multiple models from - // different vendors, even if currently only one is supported. + // Weird special case for AWS/GCP/Azure because they serve models with + // different API formats, so the outbound API alone is not sufficient to + // determine the partition. if (req.service === "aws") { modelFamily = getAwsBedrockModelFamily(model); } else if (req.service === "gcp") {