diff --git a/package-lock.json b/package-lock.json index 99971af..a7fea0c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,6 +11,7 @@ "dependencies": { "@anthropic-ai/tokenizer": "^0.0.4", "@aws-crypto/sha256-js": "^5.2.0", + "@huggingface/jinja": "^0.3.0", "@node-rs/argon2": "^1.8.3", "@smithy/eventstream-codec": "^2.1.3", "@smithy/eventstream-serde-node": "^2.1.3", @@ -18,7 +19,7 @@ "@smithy/signature-v4": "^2.1.3", "@smithy/types": "^2.10.1", "@smithy/util-utf8": "^2.1.1", - "axios": "^1.3.5", + "axios": "^1.7.4", "better-sqlite3": "^10.0.0", "check-disk-space": "^3.4.0", "cookie-parser": "^1.4.6", @@ -866,6 +867,14 @@ "node": ">=6" } }, + "node_modules/@huggingface/jinja": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.3.0.tgz", + "integrity": "sha512-GLJzso0M07ZncFkrJMIXVU4os6GFbPocD4g8fMQPMGJubf48FtGOsUORH2rtFdXPIPelz8SLBMn8ZRmOTwXm9Q==", + "engines": { + "node": ">=18" + } + }, "node_modules/@isaacs/cliui": { "version": "8.0.2", "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", @@ -1887,11 +1896,11 @@ } }, "node_modules/axios": { - "version": "1.6.1", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz", - "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==", + "version": "1.7.4", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.7.4.tgz", + "integrity": "sha512-DukmaFRnY6AzAALSH4J2M3k6PkaC+MfaAGdEERRWcC9q3/TWQwLpHR8ZRLKTdQ3aBDL64EdluRDjJqKw+BPZEw==", "dependencies": { - "follow-redirects": "^1.15.0", + "follow-redirects": "^1.15.6", "form-data": "^4.0.0", "proxy-from-env": "^1.1.0" } diff --git a/package.json b/package.json index 9e4b681..852824e 100644 --- a/package.json +++ b/package.json @@ -20,6 +20,7 @@ "dependencies": { "@anthropic-ai/tokenizer": "^0.0.4", "@aws-crypto/sha256-js": "^5.2.0", + "@huggingface/jinja": "^0.3.0", "@node-rs/argon2": "^1.8.3", "@smithy/eventstream-codec": "^2.1.3", "@smithy/eventstream-serde-node": "^2.1.3", @@ -27,7 +28,7 @@ "@smithy/signature-v4": "^2.1.3", "@smithy/types": "^2.10.1", "@smithy/util-utf8": "^2.1.1", - "axios": "^1.3.5", + "axios": "^1.7.4", "better-sqlite3": "^10.0.0", "check-disk-space": "^3.4.0", "cookie-parser": "^1.4.6", diff --git a/scripts/test-aws-signing.ts b/scripts/test-aws-signing.ts new file mode 100644 index 0000000..3d4869c --- /dev/null +++ b/scripts/test-aws-signing.ts @@ -0,0 +1,118 @@ +// uses the aws sdk to sign a request, then uses axios to send it to the bedrock REST API manually +import axios from "axios"; +import { Sha256 } from "@aws-crypto/sha256-js"; +import { SignatureV4 } from "@smithy/signature-v4"; +import { HttpRequest } from "@smithy/protocol-http"; + +const AWS_ACCESS_KEY_ID = process.env.AWS_ACCESS_KEY_ID!; +const AWS_SECRET_ACCESS_KEY = process.env.AWS_SECRET_ACCESS_KEY!; + +// Copied from amazon bedrock docs + +// List models +// ListFoundationModels +// Service: Amazon Bedrock +// List of Bedrock foundation models that you can use. For more information, see Foundation models in the +// Bedrock User Guide. +// Request Syntax +// GET /foundation-models? +// byCustomizationType=byCustomizationType&byInferenceType=byInferenceType&byOutputModality=byOutputModality&byProvider=byProvider +// HTTP/1.1 +// URI Request Parameters +// The request uses the following URI parameters. +// byCustomizationType (p. 38) +// List by customization type. +// Valid Values: FINE_TUNING +// byInferenceType (p. 38) +// List by inference type. 
+// Valid Values: ON_DEMAND | PROVISIONED +// byOutputModality (p. 38) +// List by output modality type. +// Valid Values: TEXT | IMAGE | EMBEDDING +// byProvider (p. 38) +// A Bedrock model provider. +// Pattern: ^[a-z0-9-]{1,63}$ +// Request Body +// The request does not have a request body + +// Run inference on a text model +// Send an invoke request to run inference on a Titan Text G1 - Express model. We set the accept +// parameter to accept any content type in the response. +// POST https://bedrock.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke +// -H accept: */* +// -H content-type: application/json +// Payload +// {"inputText": "Hello world"} +// Example response +// Response for the above request. +// -H content-type: application/json +// Payload +// + +const AMZ_REGION = "us-east-1"; +const AMZ_HOST = "invoke-bedrock.us-east-1.amazonaws.com"; + +async function listModels() { + const httpRequest = new HttpRequest({ + method: "GET", + protocol: "https:", + hostname: AMZ_HOST, + path: "/foundation-models", + headers: { ["Host"]: AMZ_HOST }, + }); + + const signedRequest = await signRequest(httpRequest); + const response = await axios.get( + `https://${signedRequest.hostname}${signedRequest.path}`, + { headers: signedRequest.headers } + ); + console.log(response.data); +} + +async function invokeModel() { + const model = "anthropic.claude-v1"; + const httpRequest = new HttpRequest({ + method: "POST", + protocol: "https:", + hostname: AMZ_HOST, + path: `/model/${model}/invoke`, + headers: { + ["Host"]: AMZ_HOST, + ["accept"]: "*/*", + ["content-type"]: "application/json", + }, + body: JSON.stringify({ + temperature: 0.5, + prompt: "\n\nHuman:Hello world\n\nAssistant:", + max_tokens_to_sample: 10, + }), + }); + console.log("httpRequest", httpRequest); + + const signedRequest = await signRequest(httpRequest); + const response = await axios.post( + `https://${signedRequest.hostname}${signedRequest.path}`, + signedRequest.body, + { headers: signedRequest.headers } + ); + console.log(response.status); + console.log(response.headers); + console.log(response.data); + console.log("full url", response.request.res.responseUrl); +} + +async function signRequest(request: HttpRequest) { + const signer = new SignatureV4({ + sha256: Sha256, + credentials: { + accessKeyId: AWS_ACCESS_KEY_ID, + secretAccessKey: AWS_SECRET_ACCESS_KEY, + }, + region: AMZ_REGION, + service: "bedrock", + }); + return await signer.sign(request, { signingDate: new Date() }); +} + +// listModels(); +// invokeModel(); diff --git a/src/config.ts b/src/config.ts index 8318522..7d4a5a1 100644 --- a/src/config.ts +++ b/src/config.ts @@ -428,31 +428,10 @@ export const config: Config = { ["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"], 400 ), - allowedModelFamilies: getEnvWithDefault("ALLOWED_MODEL_FAMILIES", [ - "turbo", - "gpt4", - "gpt4-32k", - "gpt4-turbo", - "gpt4o", - "claude", - "claude-opus", - "gemini-flash", - "gemini-pro", - "gemini-ultra", - "mistral-tiny", - "mistral-small", - "mistral-medium", - "mistral-large", - "aws-claude", - "aws-claude-opus", - "gcp-claude", - "gcp-claude-opus", - "azure-turbo", - "azure-gpt4", - "azure-gpt4-32k", - "azure-gpt4-turbo", - "azure-gpt4o", - ]), + allowedModelFamilies: getEnvWithDefault( + "ALLOWED_MODEL_FAMILIES", + getDefaultModelFamilies() + ), rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")), rejectMessage: getEnvWithDefault( "REJECT_MESSAGE", @@ -801,3 +780,7 @@ function parseCsv(val: string): string[] { const matches = val.match(regex) 
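+  // String.prototype.match returns null when there are no matches, hence
+  // the empty-array fallback; matched items then have surrounding quotes
+  // stripped and whitespace trimmed below.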
|| []; return matches.map((item) => item.replace(/^"|"$/g, "").trim()); } + +function getDefaultModelFamilies(): ModelFamily[] { + return MODEL_FAMILIES.filter((f) => !f.includes("dall-e")) as ModelFamily[]; +} diff --git a/src/info-page.ts b/src/info-page.ts index 2f0a7f1..580c4b6 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -29,6 +29,10 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = { "mistral-large": "Mistral Large", "aws-claude": "AWS Claude (Sonnet)", "aws-claude-opus": "AWS Claude (Opus)", + "aws-mistral-tiny": "AWS Mistral 7B", + "aws-mistral-small": "AWS Mistral Nemo", + "aws-mistral-medium": "AWS Mistral Medium", + "aws-mistral-large": "AWS Mistral Large", "gcp-claude": "GCP Claude (Sonnet)", "gcp-claude-opus": "GCP Claude (Opus)", "azure-turbo": "Azure GPT-3.5 Turbo", @@ -41,7 +45,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = { const converter = new showdown.Converter(); const customGreeting = fs.existsSync("greeting.md") - ? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}` + ? `
${fs.readFileSync("greeting.md", "utf8")}
` : ""; let infoPageHtml: string | undefined; let infoPageLastUpdated = 0; diff --git a/src/proxy/add-v1.ts b/src/proxy/add-v1.ts new file mode 100644 index 0000000..d0d8620 --- /dev/null +++ b/src/proxy/add-v1.ts @@ -0,0 +1,9 @@ +import { NextFunction, Request, Response } from "express"; + +export function addV1(req: Request, res: Response, next: NextFunction) { + // Clients don't consistently use the /v1 prefix so we'll add it for them. + if (!req.path.startsWith("/v1/") && !req.path.startsWith("/v1beta/")) { + req.url = `/v1${req.url}`; + } + next(); +} diff --git a/src/proxy/aws-claude.ts b/src/proxy/aws-claude.ts new file mode 100644 index 0000000..a00e4e6 --- /dev/null +++ b/src/proxy/aws-claude.ts @@ -0,0 +1,253 @@ +import { Request, RequestHandler, Router } from "express"; +import { createProxyMiddleware } from "http-proxy-middleware"; +import { v4 } from "uuid"; +import { logger } from "../logger"; +import { createQueueMiddleware } from "./queue"; +import { ipLimiter } from "./rate-limit"; +import { handleProxyError } from "./middleware/common"; +import { + createPreprocessorMiddleware, + signAwsRequest, + finalizeSignedRequest, + createOnProxyReqHandler, +} from "./middleware/request"; +import { + ProxyResHandlerWithBody, + createOnProxyResHandler, +} from "./middleware/response"; +import { + transformAnthropicChatResponseToAnthropicText, + transformAnthropicChatResponseToOpenAI, +} from "./anthropic"; + +/** Only used for non-streaming requests. */ +const awsResponseHandler: ProxyResHandlerWithBody = async ( + _proxyRes, + req, + res, + body +) => { + if (typeof body !== "object") { + throw new Error("Expected body to be an object"); + } + + let newBody = body; + switch (`${req.inboundApi}<-${req.outboundApi}`) { + case "openai<-anthropic-text": + req.log.info("Transforming Anthropic Text back to OpenAI format"); + newBody = transformAwsTextResponseToOpenAI(body, req); + break; + case "openai<-anthropic-chat": + req.log.info("Transforming AWS Anthropic Chat back to OpenAI format"); + newBody = transformAnthropicChatResponseToOpenAI(body); + break; + case "anthropic-text<-anthropic-chat": + req.log.info("Transforming AWS Anthropic Chat back to Text format"); + newBody = transformAnthropicChatResponseToAnthropicText(body); + break; + } + + // AWS does not always confirm the model in the response, so we have to add it + if (!newBody.model && req.body.model) { + newBody.model = req.body.model; + } + + res.status(200).json({ ...newBody, proxy: body.proxy }); +}; + +/** + * Transforms a model response from the Anthropic API to match those from the + * OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This + * is only used for non-streaming requests as streaming requests are handled + * on-the-fly. + */ +function transformAwsTextResponseToOpenAI( + awsBody: Record, + req: Request +): Record { + const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 
0); + return { + id: "aws-" + v4(), + object: "chat.completion", + created: Date.now(), + model: req.body.model, + usage: { + prompt_tokens: req.promptTokens, + completion_tokens: req.outputTokens, + total_tokens: totalTokens, + }, + choices: [ + { + message: { + role: "assistant", + content: awsBody.completion?.trim(), + }, + finish_reason: awsBody.stop_reason, + index: 0, + }, + ], + }; +} + +const awsClaudeProxy = createQueueMiddleware({ + beforeProxy: signAwsRequest, + proxyMiddleware: createProxyMiddleware({ + target: "bad-target-will-be-rewritten", + router: ({ signedRequest }) => { + if (!signedRequest) throw new Error("Must sign request before proxying"); + return `${signedRequest.protocol}//${signedRequest.hostname}`; + }, + changeOrigin: true, + selfHandleResponse: true, + logger, + on: { + proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }), + proxyRes: createOnProxyResHandler([awsResponseHandler]), + error: handleProxyError, + }, + }), +}); + +const nativeTextPreprocessor = createPreprocessorMiddleware( + { inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" }, + { afterTransform: [maybeReassignModel] } +); + +const textToChatPreprocessor = createPreprocessorMiddleware( + { inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" }, + { afterTransform: [maybeReassignModel] } +); + +/** + * Routes text completion prompts to aws anthropic-chat if they need translation + * (claude-3 based models do not support the old text completion endpoint). + */ +const preprocessAwsTextRequest: RequestHandler = (req, res, next) => { + if (req.body.model?.includes("claude-3")) { + textToChatPreprocessor(req, res, next); + } else { + nativeTextPreprocessor(req, res, next); + } +}; + +const oaiToAwsTextPreprocessor = createPreprocessorMiddleware( + { inApi: "openai", outApi: "anthropic-text", service: "aws" }, + { afterTransform: [maybeReassignModel] } +); + +const oaiToAwsChatPreprocessor = createPreprocessorMiddleware( + { inApi: "openai", outApi: "anthropic-chat", service: "aws" }, + { afterTransform: [maybeReassignModel] } +); + +/** + * Routes an OpenAI prompt to either the legacy Claude text completion endpoint + * or the new Claude chat completion endpoint, based on the requested model. + */ +const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => { + if (req.body.model?.includes("claude-3")) { + oaiToAwsChatPreprocessor(req, res, next); + } else { + oaiToAwsTextPreprocessor(req, res, next); + } +}; + +const awsClaudeRouter = Router(); +// Native(ish) Anthropic text completion endpoint. +awsClaudeRouter.post( + "/v1/complete", + ipLimiter, + preprocessAwsTextRequest, + awsClaudeProxy +); +// Native Anthropic chat completion endpoint. +awsClaudeRouter.post( + "/v1/messages", + ipLimiter, + createPreprocessorMiddleware( + { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" }, + { afterTransform: [maybeReassignModel] } + ), + awsClaudeProxy +); + +// OpenAI-to-AWS Anthropic compatibility endpoint. +awsClaudeRouter.post( + "/v1/chat/completions", + ipLimiter, + preprocessOpenAICompatRequest, + awsClaudeProxy +); + +/** + * Tries to deal with: + * - frontends sending AWS model names even when they want to use the OpenAI- + * compatible endpoint + * - frontends sending Anthropic model names that AWS doesn't recognize + * - frontends sending OpenAI model names because they expect the proxy to + * translate them + * + * If client sends AWS model ID it will be used verbatim. 
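+ * (For example, per the switch below, `claude-3-5-sonnet-20240620` maps to
+ * `anthropic.claude-3-5-sonnet-20240620-v1:0`, and `claude-instant-v1` maps
+ * to `anthropic.claude-instant-v1`.)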
Otherwise, various + * strategies are used to try to map a non-AWS model name to AWS model ID. + */ +function maybeReassignModel(req: Request) { + const model = req.body.model; + + // If it looks like an AWS model, use it as-is + if (model.includes("anthropic.claude")) { + return; + } + + // Anthropic model names can look like: + // - claude-v1 + // - claude-2.1 + // - claude-3-5-sonnet-20240620-v1:0 + const pattern = + /^(claude-)?(instant-)?(v)?(\d+)([.-](\d))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i; + const match = model.match(pattern); + + // If there's no match, fallback to Claude v2 as it is most likely to be + // available on AWS. + if (!match) { + req.body.model = `anthropic.claude-v2:1`; + return; + } + + const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match; + + if (instant) { + req.body.model = "anthropic.claude-instant-v1"; + return; + } + + const ver = minor ? `${major}.${minor}` : major; + switch (ver) { + case "1": + case "1.0": + req.body.model = "anthropic.claude-v1"; + return; + case "2": + case "2.0": + req.body.model = "anthropic.claude-v2"; + return; + case "3": + case "3.0": + if (name.includes("opus")) { + req.body.model = "anthropic.claude-3-opus-20240229-v1:0"; + } else if (name.includes("haiku")) { + req.body.model = "anthropic.claude-3-haiku-20240307-v1:0"; + } else { + req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0"; + } + return; + case "3.5": + req.body.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"; + return; + } + + // Fallback to Claude 2.1 + req.body.model = `anthropic.claude-v2:1`; + return; +} + +export const awsClaude = awsClaudeRouter; diff --git a/src/proxy/aws-mistral.ts b/src/proxy/aws-mistral.ts new file mode 100644 index 0000000..446658b --- /dev/null +++ b/src/proxy/aws-mistral.ts @@ -0,0 +1,110 @@ +import { Request } from "express"; +import { + createOnProxyResHandler, + ProxyResHandlerWithBody, +} from "./middleware/response"; +import { createQueueMiddleware } from "./queue"; +import { + createOnProxyReqHandler, + createPreprocessorMiddleware, + finalizeSignedRequest, + signAwsRequest, +} from "./middleware/request"; +import { createProxyMiddleware } from "http-proxy-middleware"; +import { logger } from "../logger"; +import { handleProxyError } from "./middleware/common"; +import { Router } from "express"; +import { ipLimiter } from "./rate-limit"; +import { detectMistralInputApi, transformMistralTextToMistralChat } from "./mistral-ai"; + +const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async ( + _proxyRes, + req, + res, + body +) => { + if (typeof body !== "object") { + throw new Error("Expected body to be an object"); + } + + let newBody = body; + if (req.inboundApi === "mistral-ai" && req.outboundApi === "mistral-text") { + newBody = transformMistralTextToMistralChat(body); + } + // AWS does not always confirm the model in the response, so we have to add it + if (!newBody.model && req.body.model) { + newBody.model = req.body.model; + } + + res.status(200).json({ ...newBody, proxy: body.proxy }); +}; + +const awsMistralProxy = createQueueMiddleware({ + beforeProxy: signAwsRequest, + proxyMiddleware: createProxyMiddleware({ + target: "bad-target-will-be-rewritten", + router: ({ signedRequest }) => { + if (!signedRequest) throw new Error("Must sign request before proxying"); + return `${signedRequest.protocol}//${signedRequest.hostname}`; + }, + changeOrigin: true, + selfHandleResponse: true, + logger, + on: { + proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }), + 
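// Buffered (non-streaming) responses are handed to the blocking handler
+      // above, which converts Mistral text completions back into chat format
+      // when the client originally sent a chat request.
+      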
proxyRes: createOnProxyResHandler([awsMistralBlockingResponseHandler]), + error: handleProxyError, + }, + }), +}); + +function maybeReassignModel(req: Request) { + const model = req.body.model; + + // If it looks like an AWS model, use it as-is + if (model.startsWith("mistral.")) { + return; + } + // Mistral 7B Instruct + else if (model.includes("7b")) { + req.body.model = "mistral.mistral-7b-instruct-v0:2"; + } + // Mistral 8x7B Instruct + else if (model.includes("8x7b")) { + req.body.model = "mistral.mixtral-8x7b-instruct-v0:1"; + } + // Mistral Large (Feb 2024) + else if (model.includes("large-2402")) { + req.body.model = "mistral.mistral-large-2402-v1:0"; + } + // Mistral Large 2 (July 2024) + else if (model.includes("large")) { + req.body.model = "mistral.mistral-large-2407-v1:0"; + } + // Mistral Small (Feb 2024) + else if (model.includes("small")) { + req.body.model = "mistral.mistral-small-2402-v1:0"; + } else { + throw new Error( + `Can't map '${model}' to a supported AWS model ID; make sure you are requesting a Mistral model supported by Amazon Bedrock` + ); + } +} + +const nativeMistralChatPreprocessor = createPreprocessorMiddleware( + { inApi: "mistral-ai", outApi: "mistral-ai", service: "aws" }, + { + beforeTransform: [detectMistralInputApi], + afterTransform: [maybeReassignModel], + } +); + +const awsMistralRouter = Router(); +awsMistralRouter.post( + "/v1/chat/completions", + ipLimiter, + nativeMistralChatPreprocessor, + awsMistralProxy +); + +export const awsMistral = awsMistralRouter; diff --git a/src/proxy/aws.ts b/src/proxy/aws.ts index c02ff0e..48c663e 100644 --- a/src/proxy/aws.ts +++ b/src/proxy/aws.ts @@ -1,337 +1,75 @@ -import { Request, RequestHandler, Response, Router } from "express"; -import { createProxyMiddleware } from "http-proxy-middleware"; -import { v4 } from "uuid"; +/* Shared code between AWS Claude and AWS Mistral endpoints. */ + +import { Request, Response, Router } from "express"; import { config } from "../config"; -import { logger } from "../logger"; -import { createQueueMiddleware } from "./queue"; -import { ipLimiter } from "./rate-limit"; -import { handleProxyError } from "./middleware/common"; -import { - createPreprocessorMiddleware, - signAwsRequest, - finalizeSignedRequest, - createOnProxyReqHandler, -} from "./middleware/request"; -import { - ProxyResHandlerWithBody, - createOnProxyResHandler, -} from "./middleware/response"; -import { transformAnthropicChatResponseToAnthropicText, transformAnthropicChatResponseToOpenAI } from "./anthropic"; -import { sendErrorToClient } from "./middleware/response/error-generator"; +import { addV1 } from "./add-v1"; +import { awsClaude } from "./aws-claude"; +import { awsMistral } from "./aws-mistral"; +import { AwsBedrockKey, keyPool } from "../shared/key-management"; -const LATEST_AWS_V2_MINOR_VERSION = "1"; - -let modelsCache: any = null; -let modelsCacheTime = 0; - -const getModelsResponse = () => { - if (new Date().getTime() - modelsCacheTime < 1000 * 60) { - return modelsCache; - } +const awsRouter = Router(); +awsRouter.get(["/:vendor?/v1/models", "/:vendor?/models"], handleModelsRequest); +awsRouter.use("/claude", addV1, awsClaude); +awsRouter.use("/mistral", addV1, awsMistral); +const MODELS_CACHE_TTL = 10000; +let modelsCache: Record = {}; +let modelsCacheTime: Record = {}; +function handleModelsRequest(req: Request, res: Response) { if (!config.awsCredentials) return { object: "list", data: [] }; + const vendor = req.params.vendor?.length + ? req.params.vendor === "claude" + ? 
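+        // the /claude route maps to Bedrock's "anthropic." model ID prefix
+        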
"anthropic" + : req.params.vendor + : "all"; + + const cacheTime = modelsCacheTime[vendor] || 0; + if (new Date().getTime() - cacheTime < MODELS_CACHE_TTL) { + return res.json(modelsCache[vendor]); + } + + const availableModelIds = new Set(); + for (const key of keyPool.list()) { + if (key.isDisabled || key.service !== "aws") continue; + (key as AwsBedrockKey).modelIds.forEach((id) => availableModelIds.add(id)); + } + // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html - const variants = [ + const models = [ "anthropic.claude-v2", "anthropic.claude-v2:1", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-3-5-sonnet-20240620-v1:0", "anthropic.claude-3-opus-20240229-v1:0", - ]; - - const models = variants.map((id) => ({ - id, - object: "model", - created: new Date().getTime(), - owned_by: "anthropic", - permission: [], - root: "claude", - parent: null, - })); - - modelsCache = { object: "list", data: models }; - modelsCacheTime = new Date().getTime(); - - return modelsCache; -}; - -const handleModelRequest: RequestHandler = (_req, res) => { - res.status(200).json(getModelsResponse()); -}; - -/** Only used for non-streaming requests. */ -const awsResponseHandler: ProxyResHandlerWithBody = async ( - _proxyRes, - req, - res, - body -) => { - if (typeof body !== "object") { - throw new Error("Expected body to be an object"); - } - - let newBody = body; - switch (`${req.inboundApi}<-${req.outboundApi}`) { - case "openai<-anthropic-text": - req.log.info("Transforming Anthropic Text back to OpenAI format"); - newBody = transformAwsTextResponseToOpenAI(body, req); - break; - case "openai<-anthropic-chat": - req.log.info("Transforming AWS Anthropic Chat back to OpenAI format"); - newBody = transformAnthropicChatResponseToOpenAI(body); - break; - case "anthropic-text<-anthropic-chat": - req.log.info("Transforming AWS Anthropic Chat back to Text format"); - newBody = transformAnthropicChatResponseToAnthropicText(body); - break; - } - - // AWS does not always confirm the model in the response, so we have to add it - if (!newBody.model && req.body.model) { - newBody.model = req.body.model; - } - - res.status(200).json({ ...newBody, proxy: body.proxy }); -}; - -/** - * Transforms a model response from the Anthropic API to match those from the - * OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This - * is only used for non-streaming requests as streaming requests are handled - * on-the-fly. - */ -function transformAwsTextResponseToOpenAI( - awsBody: Record, - req: Request -): Record { - const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 
0); - return { - id: "aws-" + v4(), - object: "chat.completion", - created: Date.now(), - model: req.body.model, - usage: { - prompt_tokens: req.promptTokens, - completion_tokens: req.outputTokens, - total_tokens: totalTokens, - }, - choices: [ - { - message: { - role: "assistant", - content: awsBody.completion?.trim(), - }, - finish_reason: awsBody.stop_reason, - index: 0, - }, - ], - }; -} - -const awsProxy = createQueueMiddleware({ - beforeProxy: signAwsRequest, - proxyMiddleware: createProxyMiddleware({ - target: "bad-target-will-be-rewritten", - router: ({ signedRequest }) => { - if (!signedRequest) throw new Error("Must sign request before proxying"); - return `${signedRequest.protocol}//${signedRequest.hostname}`; - }, - changeOrigin: true, - selfHandleResponse: true, - logger, - on: { - proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }), - proxyRes: createOnProxyResHandler([awsResponseHandler]), - error: handleProxyError, - }, - }), -}); - -const nativeTextPreprocessor = createPreprocessorMiddleware( - { inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" }, - { afterTransform: [maybeReassignModel] } -); - -const textToChatPreprocessor = createPreprocessorMiddleware( - { inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" }, - { afterTransform: [maybeReassignModel] } -); - -/** - * Routes text completion prompts to aws anthropic-chat if they need translation - * (claude-3 based models do not support the old text completion endpoint). - */ -const preprocessAwsTextRequest: RequestHandler = (req, res, next) => { - if (req.body.model?.includes("claude-3")) { - textToChatPreprocessor(req, res, next); - } else { - nativeTextPreprocessor(req, res, next); - } -}; - -const oaiToAwsTextPreprocessor = createPreprocessorMiddleware( - { inApi: "openai", outApi: "anthropic-text", service: "aws" }, - { afterTransform: [maybeReassignModel] } -); - -const oaiToAwsChatPreprocessor = createPreprocessorMiddleware( - { inApi: "openai", outApi: "anthropic-chat", service: "aws" }, - { afterTransform: [maybeReassignModel] } -); - -/** - * Routes an OpenAI prompt to either the legacy Claude text completion endpoint - * or the new Claude chat completion endpoint, based on the requested model. - */ -const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => { - if (req.body.model?.includes("claude-3")) { - oaiToAwsChatPreprocessor(req, res, next); - } else { - oaiToAwsTextPreprocessor(req, res, next); - } -}; - -const awsRouter = Router(); -awsRouter.get("/v1/models", handleModelRequest); -// Native(ish) Anthropic text completion endpoint. -awsRouter.post("/v1/complete", ipLimiter, preprocessAwsTextRequest, awsProxy); -// Native Anthropic chat completion endpoint. -awsRouter.post( - "/v1/messages", - ipLimiter, - createPreprocessorMiddleware( - { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" }, - { afterTransform: [maybeReassignModel] } - ), - awsProxy -); -// Temporary force-Claude3 endpoint -awsRouter.post( - "/v1/sonnet/:action(complete|messages)", - ipLimiter, - handleCompatibilityRequest, - createPreprocessorMiddleware({ - inApi: "anthropic-text", - outApi: "anthropic-chat", - service: "aws", - }), - awsProxy -); - -// OpenAI-to-AWS Anthropic compatibility endpoint. 
-awsRouter.post( - "/v1/chat/completions", - ipLimiter, - preprocessOpenAICompatRequest, - awsProxy -); - -/** - * Tries to deal with: - * - frontends sending AWS model names even when they want to use the OpenAI- - * compatible endpoint - * - frontends sending Anthropic model names that AWS doesn't recognize - * - frontends sending OpenAI model names because they expect the proxy to - * translate them - * - * If client sends AWS model ID it will be used verbatim. Otherwise, various - * strategies are used to try to map a non-AWS model name to AWS model ID. - */ -function maybeReassignModel(req: Request) { - const model = req.body.model; - - // If it looks like an AWS model, use it as-is - if (model.includes("anthropic.claude")) { - return; - } - - // Anthropic model names can look like: - // - claude-v1 - // - claude-2.1 - // - claude-3-5-sonnet-20240620-v1:0 - const pattern = - /^(claude-)?(instant-)?(v)?(\d+)([.-](\d{1}))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i; - const match = model.match(pattern); - - // If there's no match, fallback to Claude v2 as it is most likely to be - // available on AWS. - if (!match) { - req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`; - return; - } - - const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match; - - if (instant) { - req.body.model = "anthropic.claude-instant-v1"; - return; - } - - const ver = minor ? `${major}.${minor}` : major; - switch (ver) { - case "1": - case "1.0": - req.body.model = "anthropic.claude-v1"; - return; - case "2": - case "2.0": - req.body.model = "anthropic.claude-v2"; - return; - case "3": - case "3.0": - if (name.includes("opus")) { - req.body.model = "anthropic.claude-3-opus-20240229-v1:0"; - } else if (name.includes("haiku")) { - req.body.model = "anthropic.claude-3-haiku-20240307-v1:0"; - } else { - req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0"; - } - return; - case "3.5": - req.body.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"; - return; - } - - // Fallback to Claude 2.1 - req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`; - return; -} - -export function handleCompatibilityRequest( - req: Request, - res: Response, - next: any -) { - const action = req.params.action; - const alreadyInChatFormat = Boolean(req.body.messages); - const compatModel = "anthropic.claude-3-5-sonnet-20240620-v1:0"; - req.log.info( - { inputModel: req.body.model, compatModel, alreadyInChatFormat }, - "Handling AWS compatibility request" - ); - - if (action === "messages" || alreadyInChatFormat) { - return sendErrorToClient({ - req, - res, - options: { - title: "Unnecessary usage of compatibility endpoint", - message: `Your client seems to already support the new Claude API format. 
This endpoint is intended for clients that do not yet support the new format.\nUse the normal \`/aws/claude\` proxy endpoint instead.`, - format: "unknown", - statusCode: 400, - reqId: req.id, - obj: { - requested_endpoint: "/aws/claude/sonnet", - correct_endpoint: "/aws/claude", - }, - }, + "mistral.mistral-7b-instruct-v0:2", + "mistral.mixtral-8x7b-instruct-v0:1", + "mistral.mistral-large-2402-v1:0", + "mistral.mistral-large-2407-v1:0", + "mistral.mistral-small-2402-v1:0", + ] + .filter((id) => availableModelIds.has(id)) + .map((id) => { + const vendor = id.match(/^(.*)\./)?.[1]; + return { + id, + object: "model", + created: new Date().getTime(), + owned_by: vendor, + permission: [], + root: vendor, + parent: null, + }; }); - } - req.body.model = compatModel; - next(); + modelsCache[vendor] = { + object: "list", + data: models.filter((m) => vendor === "all" || m.root === vendor), + }; + modelsCacheTime[vendor] = new Date().getTime(); + + return res.json(modelsCache[vendor]); } export const aws = awsRouter; diff --git a/src/proxy/middleware/common.ts b/src/proxy/middleware/common.ts index 0274522..2f815e4 100644 --- a/src/proxy/middleware/common.ts +++ b/src/proxy/middleware/common.ts @@ -221,9 +221,12 @@ export function getCompletionFromBody(req: Request, body: Record) { switch (format) { case "openai": case "mistral-ai": - // Can be null if the model wants to invoke tools rather than return a - // completion. - return body.choices[0].message.content || ""; + // Few possible values: + // - choices[0].message.content + // - choices[0].message with no content if model is invoking a tool + return body.choices?.[0]?.message?.content || ""; + case "mistral-text": + return body.outputs?.[0]?.text || ""; case "openai-text": return body.choices[0].text; case "anthropic-chat": @@ -260,22 +263,22 @@ export function getCompletionFromBody(req: Request, body: Record) { } } -export function getModelFromBody(req: Request, body: Record) { +export function getModelFromBody(req: Request, resBody: Record) { const format = req.outboundApi; switch (format) { case "openai": case "openai-text": + return resBody.model; case "mistral-ai": - return body.model; + case "mistral-text": case "openai-image": + case "google-ai": + // These formats don't have a model in the response body. return req.body.model; case "anthropic-chat": case "anthropic-text": // Anthropic confirms the model in the response, but AWS Claude doesn't. - return body.model || req.body.model; - case "google-ai": - // Google doesn't confirm the model in the response. - return req.body.model; + return resBody.model || req.body.model; default: assertNever(format); } diff --git a/src/proxy/middleware/request/onproxyreq/add-key.ts b/src/proxy/middleware/request/onproxyreq/add-key.ts index 27b2dc3..160e8d0 100644 --- a/src/proxy/middleware/request/onproxyreq/add-key.ts +++ b/src/proxy/middleware/request/onproxyreq/add-key.ts @@ -38,7 +38,10 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => { // translation now reassigns the model earlier in the request pipeline. 
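+    // For these formats, the key pool resolves a key from the requested
+    // model ID directly, so the hardcoded placeholder model that was
+    // previously passed is no longer needed.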
case "anthropic-text": case "anthropic-chat": - assignedKey = keyPool.get("claude-v1", service, needsMultimodal); + case "mistral-ai": + case "mistral-text": + case "google-ai": + assignedKey = keyPool.get(body.model, service); break; case "openai-text": assignedKey = keyPool.get("gpt-3.5-turbo-instruct", service); @@ -47,10 +50,8 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => { assignedKey = keyPool.get("dall-e-3", service); break; case "openai": - case "google-ai": - case "mistral-ai": throw new Error( - `add-key should not be called for outbound API ${outboundApi}` + `Outbound API ${outboundApi} is not supported for ${inboundApi}` ); default: assertNever(outboundApi); diff --git a/src/proxy/middleware/request/preprocessor-factory.ts b/src/proxy/middleware/request/preprocessor-factory.ts index 846bed1..3fc1a54 100644 --- a/src/proxy/middleware/request/preprocessor-factory.ts +++ b/src/proxy/middleware/request/preprocessor-factory.ts @@ -86,7 +86,7 @@ async function executePreprocessors( const msg = error?.issues ?.map((issue: ZodIssue) => issue.message) .join("; "); - req.log.info(msg, "Prompt validation failed."); + req.log.warn({ issues: msg }, "Prompt validation failed."); } else { req.log.error(error, "Error while executing request preprocessor"); } diff --git a/src/proxy/middleware/request/preprocessors/count-prompt-tokens.ts b/src/proxy/middleware/request/preprocessors/count-prompt-tokens.ts index 130bedf..ec00e55 100644 --- a/src/proxy/middleware/request/preprocessors/count-prompt-tokens.ts +++ b/src/proxy/middleware/request/preprocessors/count-prompt-tokens.ts @@ -2,7 +2,6 @@ import { RequestPreprocessor } from "../index"; import { countTokens } from "../../../../shared/tokenization"; import { assertNever } from "../../../../shared/utils"; import { - AnthropicChatMessage, GoogleAIChatMessage, MistralAIChatMessage, OpenAIChatMessage, @@ -50,9 +49,11 @@ export const countPromptTokens: RequestPreprocessor = async (req) => { result = await countTokens({ req, prompt, service }); break; } - case "mistral-ai": { + case "mistral-ai": + case "mistral-text": { req.outputTokens = req.body.max_tokens; - const prompt: MistralAIChatMessage[] = req.body.messages; + const prompt: string | MistralAIChatMessage[] = + req.body.messages ?? 
req.body.prompt; result = await countTokens({ req, prompt, service }); break; } diff --git a/src/proxy/middleware/request/preprocessors/language-filter.ts b/src/proxy/middleware/request/preprocessors/language-filter.ts index 9610cb4..345aa75 100644 --- a/src/proxy/middleware/request/preprocessors/language-filter.ts +++ b/src/proxy/middleware/request/preprocessors/language-filter.ts @@ -56,8 +56,6 @@ function getPromptFromRequest(req: Request) { switch (service) { case "anthropic-chat": return flattenAnthropicMessages(body.messages); - case "anthropic-text": - return body.prompt; case "openai": case "mistral-ai": return body.messages @@ -72,8 +70,10 @@ function getPromptFromRequest(req: Request) { return `${msg.role}: ${text}`; }) .join("\n\n"); + case "anthropic-text": case "openai-text": case "openai-image": + case "mistral-text": return body.prompt; case "google-ai": return body.prompt.text; diff --git a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts index b83ba14..eb8ec68 100644 --- a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts +++ b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts @@ -1,4 +1,4 @@ -import express from "express"; +import express, { Request } from "express"; import { Sha256 } from "@aws-crypto/sha256-js"; import { SignatureV4 } from "@smithy/signature-v4"; import { HttpRequest } from "@smithy/protocol-http"; @@ -8,6 +8,10 @@ import { } from "../../../../shared/api-schemas"; import { keyPool } from "../../../../shared/key-management"; import { RequestPreprocessor } from "../index"; +import { + AWSMistralV1ChatCompletionsSchema, + AWSMistralV1TextCompletionsSchema, +} from "../../../../shared/api-schemas/mistral-ai"; const AMZ_HOST = process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com"; @@ -29,38 +33,6 @@ export const signAwsRequest: RequestPreprocessor = async (req) => { req.body.prompt = preamble + req.body.prompt; } - // AWS uses mostly the same parameters as Anthropic, with a few removed params - // and much stricter validation on unused parameters. Rather than treating it - // as a separate schema we will use the anthropic ones and strip the unused - // parameters. 
- // TODO: This should happen in transform-outbound-payload.ts - let strippedParams: Record; - if (req.outboundApi === "anthropic-chat") { - strippedParams = AnthropicV1MessagesSchema.pick({ - messages: true, - system: true, - max_tokens: true, - stop_sequences: true, - temperature: true, - top_k: true, - top_p: true, - }) - .strip() - .parse(req.body); - strippedParams.anthropic_version = "bedrock-2023-05-31"; - } else { - strippedParams = AnthropicV1TextSchema.pick({ - prompt: true, - max_tokens_to_sample: true, - stop_sequences: true, - temperature: true, - top_k: true, - top_p: true, - }) - .strip() - .parse(req.body); - } - const credential = getCredentialParts(req); const host = AMZ_HOST.replace("%REGION%", credential.region); // AWS only uses 2023-06-01 and does not actually check this header, but we @@ -78,7 +50,7 @@ export const signAwsRequest: RequestPreprocessor = async (req) => { ["Host"]: host, ["content-type"]: "application/json", }, - body: JSON.stringify(strippedParams), + body: JSON.stringify(applyAwsStrictValidation(req)), }); if (stream) { @@ -128,3 +100,48 @@ async function sign(request: HttpRequest, credential: Credential) { return signer.sign(request); } + +function applyAwsStrictValidation(req: Request): unknown { + // AWS uses vendor API formats but imposes additional (more strict) validation + // rules, namely that extraneous parameters are not allowed. We will validate + // using the vendor's zod schema but apply `.strip` to ensure that any + // extraneous parameters are removed. + let strippedParams: Record = {}; + switch (req.outboundApi) { + case "anthropic-text": + strippedParams = AnthropicV1TextSchema.pick({ + prompt: true, + max_tokens_to_sample: true, + stop_sequences: true, + temperature: true, + top_k: true, + top_p: true, + }) + .strip() + .parse(req.body); + break; + case "anthropic-chat": + strippedParams = AnthropicV1MessagesSchema.pick({ + messages: true, + system: true, + max_tokens: true, + stop_sequences: true, + temperature: true, + top_k: true, + top_p: true, + }) + .strip() + .parse(req.body); + strippedParams.anthropic_version = "bedrock-2023-05-31"; + break; + case "mistral-ai": + strippedParams = AWSMistralV1ChatCompletionsSchema.parse(req.body); + break; + case "mistral-text": + strippedParams = AWSMistralV1TextCompletionsSchema.parse(req.body); + break; + default: + throw new Error("Unexpected outbound API for AWS."); + } + return strippedParams; +} diff --git a/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts b/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts index 1186367..42da63e 100644 --- a/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts +++ b/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts @@ -1,3 +1,4 @@ +import { Request } from "express"; import { API_REQUEST_VALIDATORS, API_REQUEST_TRANSFORMERS, @@ -12,29 +13,23 @@ import { RequestPreprocessor } from "../index"; /** Transforms an incoming request body to one that matches the target API. 
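+ * Native prompts (inbound format identical to the outbound format) need no
+ * transformation and are only validated; all others are dispatched to the
+ * matching `inbound->outbound` API transformer.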
*/ export const transformOutboundPayload: RequestPreprocessor = async (req) => { - const sameService = req.inboundApi === req.outboundApi; const alreadyTransformed = req.retryCount > 0; const notTransformable = !isTextGenerationRequest(req) && !isImageGenerationRequest(req); if (alreadyTransformed || notTransformable) return; - // TODO: this should be an APIFormatTransformer - if (req.inboundApi === "mistral-ai") { - const messages = req.body.messages; - req.body.messages = fixMistralPrompt(messages); - req.log.info( - { old: messages.length, new: req.body.messages.length }, - "Fixed Mistral prompt" - ); - } + applyMistralPromptFixes(req); - if (sameService) { + // Native prompts are those which were already provided by the client in the + // target API format. We don't need to transform them. + const isNativePrompt = req.inboundApi === req.outboundApi; + if (isNativePrompt) { const result = API_REQUEST_VALIDATORS[req.inboundApi].safeParse(req.body); if (!result.success) { req.log.warn( { issues: result.error.issues, body: req.body }, - "Request validation failed" + "Native prompt request validation failed." ); throw result.error; } @@ -42,11 +37,12 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => { return; } + // Prompt requires translation from one API format to another. const transformation = `${req.inboundApi}->${req.outboundApi}` as const; const transFn = API_REQUEST_TRANSFORMERS[transformation]; if (transFn) { - req.log.info({ transformation }, "Transforming request"); + req.log.info({ transformation }, "Transforming request..."); req.body = await transFn(req); return; } @@ -55,3 +51,37 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => { `${transformation} proxying is not supported. Make sure your client is configured to send requests in the correct format and to the correct endpoint.` ); }; + +// handles weird cases that don't fit into our abstractions +function applyMistralPromptFixes(req: Request): void { + if (req.inboundApi === "mistral-ai") { + // Mistral Chat is very similar to OpenAI but not identical and many clients + // don't properly handle the differences. We will try to validate the + // mistral prompt and try to fix it if it fails. It will be re-validated + // after this function returns. + const result = API_REQUEST_VALIDATORS["mistral-ai"].safeParse(req.body); + if (!result.success) { + const messages = req.body.messages; + req.body.messages = fixMistralPrompt(messages); + req.log.info( + { old: messages.length, new: req.body.messages.length }, + "Applied Mistral chat prompt fixes." + ); + } + + // If the prompt relies on `prefix: true` for the last message, we need to + // convert it to a text completions request because Mistral support for + // this feature is limited (and completely broken on AWS Mistral). + const { messages } = req.body; + const lastMessage = messages && messages[messages.length - 1]; + if (lastMessage && lastMessage.role === "assistant") { + // enable prefix if client forgot, otherwise the template will insert an + // eos token which is very unlikely to be what the client wants. + lastMessage.prefix = true; + req.outboundApi = "mistral-text"; + req.log.info( + "Native Mistral chat prompt relies on assistant message prefix. Converting to text completions request." 
+ ); + } + } +} diff --git a/src/proxy/middleware/request/preprocessors/validate-context-size.ts b/src/proxy/middleware/request/preprocessors/validate-context-size.ts index c786b06..1cf9854 100644 --- a/src/proxy/middleware/request/preprocessors/validate-context-size.ts +++ b/src/proxy/middleware/request/preprocessors/validate-context-size.ts @@ -38,6 +38,7 @@ export const validateContextSize: RequestPreprocessor = async (req) => { proxyMax = GOOGLE_AI_MAX_CONTEXT; break; case "mistral-ai": + case "mistral-text": proxyMax = MISTRAL_AI_MAX_CONTENT; break; case "openai-image": diff --git a/src/proxy/middleware/request/preprocessors/validate-vision.ts b/src/proxy/middleware/request/preprocessors/validate-vision.ts index b72f6fd..5940222 100644 --- a/src/proxy/middleware/request/preprocessors/validate-vision.ts +++ b/src/proxy/middleware/request/preprocessors/validate-vision.ts @@ -28,6 +28,7 @@ export const validateVision: RequestPreprocessor = async (req) => { case "anthropic-text": case "google-ai": case "mistral-ai": + case "mistral-text": case "openai-image": case "openai-text": return; diff --git a/src/proxy/middleware/response/error-generator.ts b/src/proxy/middleware/response/error-generator.ts index 4c39445..f2bfdad 100644 --- a/src/proxy/middleware/response/error-generator.ts +++ b/src/proxy/middleware/response/error-generator.ts @@ -189,6 +189,11 @@ export function buildSpoofedCompletion({ }, ], }; + case "mistral-text": + return { + outputs: [{ text: content, stop_reason: title }], + model, + } case "openai-text": return { id: "error-" + id, @@ -267,6 +272,11 @@ export function buildSpoofedSSE({ choices: [{ delta: { content }, index: 0, finish_reason: title }], }; break; + case "mistral-text": + event = { + outputs: [{ text: content, stop_reason: title }], + }; + break; case "openai-text": event = { id: "cmpl-" + id, diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts index f84e263..a15eb51 100644 --- a/src/proxy/middleware/response/handle-streamed-response.ts +++ b/src/proxy/middleware/response/handle-streamed-response.ts @@ -22,18 +22,19 @@ import { SSEStreamAdapter } from "./streaming/sse-stream-adapter"; const pipelineAsync = promisify(pipeline); /** - * `handleStreamedResponse` consumes and transforms a streamed response from the - * upstream service, forwarding events to the client in their requested format. + * `handleStreamedResponse` consumes a streamed response from the upstream API, + * decodes chunk-by-chunk into a stream of events, transforms those events into + * the client's requested format, and forwards the result to the client. + * * After the entire stream has been consumed, it resolves with the full response * body so that subsequent middleware in the chain can process it as if it were - * a non-streaming response. + * a non-streaming response (to count output tokens, track usage, etc). * - * In the event of an error, the request's streaming flag is unset and the non- - * streaming response handler is called instead. - * - * If the error is retryable, that handler will re-enqueue the request and also - * reset the streaming flag. Unfortunately the streaming flag is set and unset - * in multiple places, so it's hard to keep track of. + * In the event of an error, the request's streaming flag is unset and the + * request is bounced back to the non-streaming response handler. If the error + * is retryable, that handler will re-enqueue the request and also reset the + * streaming flag. 
Unfortunately the streaming flag is set and unset in multiple + * places, so it's hard to keep track of. */ export const handleStreamedResponse: RawResponseBodyHandler = async ( proxyRes, @@ -70,13 +71,21 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( logger: req.log, }; - // Decoder turns the raw response stream into a stream of events in some - // format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc). + // While the request is streaming, aggregator collects all events so that we + // can compile them into a single response object and publish that to the + // remaining middleware. Because we have an OpenAI transformer for every + // supported format, EventAggregator always consumes OpenAI events so that we + // only have to write one aggregator (OpenAI input) for each output format. + const aggregator = new EventAggregator(req); + + // Decoder reads from the raw response buffer and produces a stream of + // discrete events in some format (text/event-stream, vnd.amazon.event-stream, + // streaming JSON, etc). const decoder = getDecoder({ ...streamOptions, input: proxyRes }); - // Adapter transforms the decoded events into server-sent events. + // Adapter consumes the decoded events and produces server-sent events so we + // have a standard event format for the client and to translate between API + // message formats. const adapter = new SSEStreamAdapter(streamOptions); - // Aggregator compiles all events into a single response object. - const aggregator = new EventAggregator({ format: req.outboundApi }); // Transformer converts server-sent events from one vendor's API message // format to another. const transformer = new SSEMessageTransformer({ diff --git a/src/proxy/middleware/response/log-prompt.ts b/src/proxy/middleware/response/log-prompt.ts index cfa2aa0..7983a1c 100644 --- a/src/proxy/middleware/response/log-prompt.ts +++ b/src/proxy/middleware/response/log-prompt.ts @@ -11,7 +11,8 @@ import { ProxyResHandlerWithBody } from "."; import { assertNever } from "../../../shared/utils"; import { AnthropicChatMessage, - flattenAnthropicMessages, GoogleAIChatMessage, + flattenAnthropicMessages, + GoogleAIChatMessage, MistralAIChatMessage, OpenAIChatMessage, } from "../../../shared/api-schemas"; @@ -76,6 +77,8 @@ const getPromptForRequest = ( case "anthropic-chat": return { system: req.body.system, messages: req.body.messages }; case "openai-text": + case "anthropic-text": + case "mistral-text": return req.body.prompt; case "openai-image": return { @@ -85,8 +88,6 @@ const getPromptForRequest = ( quality: req.body.quality, revisedPrompt: responseBody.data[0].revised_prompt, }; - case "anthropic-text": - return req.body.prompt; case "google-ai": return { contents: req.body.contents }; default: @@ -113,9 +114,7 @@ const flattenMessages = ( if (isGoogleAIChatPrompt(val)) { return val.contents .map(({ parts, role }) => { - const text = parts - .map((p) => p.text) - .join("\n"); + const text = parts.map((p) => p.text).join("\n"); return `${role}: ${text}`; }) .join("\n"); @@ -143,11 +142,7 @@ const flattenMessages = ( function isGoogleAIChatPrompt( val: unknown ): val is { contents: GoogleAIChatMessage[] } { - return ( - typeof val === "object" && - val !== null && - "contents" in val - ); + return typeof val === "object" && val !== null && "contents" in val; } function isAnthropicChatPrompt( diff --git a/src/proxy/middleware/response/streaming/aggregators/mistral-chat.ts b/src/proxy/middleware/response/streaming/aggregators/mistral-chat.ts new file mode 
100644 index 0000000..9ca6f32 --- /dev/null +++ b/src/proxy/middleware/response/streaming/aggregators/mistral-chat.ts @@ -0,0 +1,51 @@ +import { OpenAIChatCompletionStreamEvent } from "../index"; + +/* + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Genshin Impact is an action role-play" + }, + "stop_reason": "length" + } + ], + */ +export type MistralChatCompletionResponse = { + choices: { + index: number; + message: { role: string; content: string }; + finish_reason: string | null; + }[]; +}; + +/** + * Given a list of OpenAI chat completion events, compiles them into a single + * finalized Mistral chat completion response so that non-streaming middleware + * can operate on it as if it were a blocking response. + */ +export function mergeEventsForMistralChat( + events: OpenAIChatCompletionStreamEvent[] +): MistralChatCompletionResponse { + let merged: MistralChatCompletionResponse = { + choices: [ + { index: 0, message: { role: "", content: "" }, finish_reason: "" }, + ], + }; + merged = events.reduce((acc, event, i) => { + // The first event will only contain role assignment and response metadata + if (i === 0) { + acc.choices[0].message.role = event.choices[0].delta.role ?? "assistant"; + return acc; + } + + acc.choices[0].finish_reason = event.choices[0].finish_reason ?? ""; + if (event.choices[0].delta.content) { + acc.choices[0].message.content += event.choices[0].delta.content; + } + + return acc; + }, merged); + return merged; +} diff --git a/src/proxy/middleware/response/streaming/aggregators/mistral-text.ts b/src/proxy/middleware/response/streaming/aggregators/mistral-text.ts new file mode 100644 index 0000000..afe9e3d --- /dev/null +++ b/src/proxy/middleware/response/streaming/aggregators/mistral-text.ts @@ -0,0 +1,33 @@ +import { OpenAIChatCompletionStreamEvent } from "../index"; + +export type MistralTextCompletionResponse = { + outputs: { + text: string; + stop_reason: string | null; + }[]; +}; + +/** + * Given a list of OpenAI chat completion events, compiles them into a single + * finalized Mistral text completion response so that non-streaming middleware + * can operate on it as if it were a blocking response. + */ +export function mergeEventsForMistralText( + events: OpenAIChatCompletionStreamEvent[] +): MistralTextCompletionResponse { + let merged: MistralTextCompletionResponse = { + outputs: [{ text: "", stop_reason: "" }], + }; + merged = events.reduce((acc, event, i) => { + // The first event will only contain role assignment and response metadata + if (i === 0) { + return acc; + } + + acc.outputs[0].text += event.choices[0].delta.content ?? ""; + acc.outputs[0].stop_reason = event.choices[0].finish_reason ?? ""; + + return acc; + }, merged); + return merged; +} diff --git a/src/proxy/middleware/response/streaming/aws-event-stream-decoder.ts b/src/proxy/middleware/response/streaming/aws-event-stream-decoder.ts index b394142..fe06289 100644 --- a/src/proxy/middleware/response/streaming/aws-event-stream-decoder.ts +++ b/src/proxy/middleware/response/streaming/aws-event-stream-decoder.ts @@ -24,7 +24,7 @@ export function getAwsEventStreamDecoder(params: { if (eventType === "chunk") { result = input[eventType]; } else { - // AWS unmarshaller treats non-chunk (errors and exceptions) oddly. + // AWS unmarshaller treats non-chunk events (errors and exceptions) oddly. 
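+        // Re-wrapping the payload under its event-type key preserves the
+        // exception name for downstream stream handlers.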
result = { [eventType]: input[eventType] } as any; } return result; diff --git a/src/proxy/middleware/response/streaming/event-aggregator.ts b/src/proxy/middleware/response/streaming/event-aggregator.ts index d80c738..dd856f0 100644 --- a/src/proxy/middleware/response/streaming/event-aggregator.ts +++ b/src/proxy/middleware/response/streaming/event-aggregator.ts @@ -1,3 +1,4 @@ +import express from "express"; import { APIFormat } from "../../../../shared/key-management"; import { assertNever } from "../../../../shared/utils"; import { @@ -6,8 +7,13 @@ import { mergeEventsForAnthropicText, mergeEventsForOpenAIChat, mergeEventsForOpenAIText, + mergeEventsForMistralChat, + mergeEventsForMistralText, AnthropicV2StreamEvent, OpenAIChatCompletionStreamEvent, + mistralAIToOpenAI, + MistralAIStreamEvent, + MistralChatCompletionEvent, } from "./index"; /** @@ -15,45 +21,74 @@ import { * compiles them into a single finalized response for downstream middleware. */ export class EventAggregator { - private readonly format: APIFormat; + private readonly model: string; + private readonly requestFormat: APIFormat; + private readonly responseFormat: APIFormat; private readonly events: OpenAIChatCompletionStreamEvent[]; - constructor({ format }: { format: APIFormat }) { + constructor({ body, inboundApi, outboundApi }: express.Request) { this.events = []; - this.format = format; + this.requestFormat = inboundApi; + this.responseFormat = outboundApi; + this.model = body.model; } - addEvent(event: OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent) { + addEvent( + event: + | OpenAIChatCompletionStreamEvent + | AnthropicV2StreamEvent + | MistralAIStreamEvent + ) { if (eventIsOpenAIEvent(event)) { this.events.push(event); } else { // horrible special case. previously all transformers' target format was // openai, so the event aggregator could conveniently assume all incoming // events were in openai format. - // now we have added anthropic-chat-to-text, so aggregator needs to know - // how to collapse events from two formats. - // because that is annoying, we will simply transform anthropic events to - // openai (even if the client didn't ask for openai) so we don't have to - // write aggregation logic for anthropic chat (which is also a troublesome - // stateful format). - const openAIEvent = anthropicV2ToOpenAI({ - data: `event: completion\ndata: ${JSON.stringify(event)}\n\n`, - lastPosition: -1, - index: 0, - fallbackId: event.log_id || "event-aggregator-fallback", - fallbackModel: event.model || "claude-3-fallback", - }); - if (openAIEvent.event) { - this.events.push(openAIEvent.event); + // now we have added some transformers that convert between non-openai + // formats, so aggregator needs to know how to collapse for more than + // just openai. + // because writing aggregation logic for every possible output format is + // annoying, we will just transform any non-openai output events to openai + // format (even if the client did not request openai at all) so that we + // still only need to write aggregators for openai SSEs. 
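+      // e.g. an Anthropic V2 `{ completion: "Hi" }` delta becomes an OpenAI
+      // `chat.completion.chunk` whose `delta.content` is "Hi" before being
+      // collected into `this.events`.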
+ let openAIEvent: OpenAIChatCompletionStreamEvent | undefined; + switch (this.requestFormat) { + case "anthropic-text": + if (!eventIsAnthropicV2Event(event)) { + throw new Error(`Bad event for Anthropic V2 SSE aggregation`); + } + openAIEvent = anthropicV2ToOpenAI({ + data: `event: completion\ndata: ${JSON.stringify(event)}\n\n`, + lastPosition: -1, + index: 0, + fallbackId: event.log_id || "fallback-" + Date.now(), + fallbackModel: event.model || this.model || "fallback-claude-3", + })?.event; + break; + case "mistral-ai": + if (!eventIsMistralChatEvent(event)) { + throw new Error(`Bad event for Mistral SSE aggregation`); + } + openAIEvent = mistralAIToOpenAI({ + data: `data: ${JSON.stringify(event)}\n\n`, + lastPosition: -1, + index: 0, + fallbackId: "fallback-" + Date.now(), + fallbackModel: this.model || "fallback-mistral", + })?.event; + break; + } + if (openAIEvent) { + this.events.push(openAIEvent); } } } getFinalResponse() { - switch (this.format) { + switch (this.responseFormat) { case "openai": - case "google-ai": - case "mistral-ai": + case "google-ai": // TODO: this is probably wrong now that we support native Google Makersuite prompts return mergeEventsForOpenAIChat(this.events); case "openai-text": return mergeEventsForOpenAIText(this.events); @@ -61,10 +96,16 @@ export class EventAggregator { return mergeEventsForAnthropicText(this.events); case "anthropic-chat": return mergeEventsForAnthropicChat(this.events); + case "mistral-ai": + return mergeEventsForMistralChat(this.events); + case "mistral-text": + return mergeEventsForMistralText(this.events); case "openai-image": - throw new Error(`SSE aggregation not supported for ${this.format}`); + throw new Error( + `SSE aggregation not supported for ${this.responseFormat}` + ); default: - assertNever(this.format); + assertNever(this.responseFormat); } } @@ -78,3 +119,13 @@ function eventIsOpenAIEvent( ): event is OpenAIChatCompletionStreamEvent { return event?.object === "chat.completion.chunk"; } + +function eventIsAnthropicV2Event(event: any): event is AnthropicV2StreamEvent { + return event?.completion; +} + +function eventIsMistralChatEvent( + event: any +): event is MistralChatCompletionEvent { + return event?.choices; +} diff --git a/src/proxy/middleware/response/streaming/index.ts b/src/proxy/middleware/response/streaming/index.ts index 402c233..5274fb0 100644 --- a/src/proxy/middleware/response/streaming/index.ts +++ b/src/proxy/middleware/response/streaming/index.ts @@ -7,6 +7,25 @@ export type SSEResponseTransformArgs<S = Record<string, any>> = { state?: S; }; +export type MistralChatCompletionEvent = { + choices: { + index: number; + message: { role: string; content: string }; + stop_reason: string | null; + }[]; +}; +export type MistralTextCompletionEvent = { + outputs: { text: string; stop_reason: string | null }[]; +}; +export type MistralAIStreamEvent = { + "amazon-bedrock-invocationMetrics"?: { + inputTokenCount: number; + outputTokenCount: number; + invocationLatency: number; + firstByteLatency: number; + }; +} & (MistralChatCompletionEvent | MistralTextCompletionEvent); + export type AnthropicV2StreamEvent = { log_id?: string; model?: string; @@ -41,8 +60,12 @@ export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai"; export { anthropicChatToAnthropicV2 } from "./transformers/anthropic-chat-to-anthropic-v2"; export { anthropicChatToOpenAI } from "./transformers/anthropic-chat-to-openai"; export { googleAIToOpenAI } from "./transformers/google-ai-to-openai"; +export { mistralAIToOpenAI } from
"./transformers/mistral-ai-to-openai"; +export { mistralTextToMistralChat } from "./transformers/mistral-text-to-mistral-chat"; export { passthroughToOpenAI } from "./transformers/passthrough-to-openai"; export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat"; export { mergeEventsForOpenAIText } from "./aggregators/openai-text"; export { mergeEventsForAnthropicText } from "./aggregators/anthropic-text"; export { mergeEventsForAnthropicChat } from "./aggregators/anthropic-chat"; +export { mergeEventsForMistralChat } from "./aggregators/mistral-chat"; +export { mergeEventsForMistralText } from "./aggregators/mistral-text"; diff --git a/src/proxy/middleware/response/streaming/sse-message-transformer.ts b/src/proxy/middleware/response/streaming/sse-message-transformer.ts index 800b286..daf5c6a 100644 --- a/src/proxy/middleware/response/streaming/sse-message-transformer.ts +++ b/src/proxy/middleware/response/streaming/sse-message-transformer.ts @@ -11,8 +11,11 @@ import { googleAIToOpenAI, OpenAIChatCompletionStreamEvent, openAITextToOpenAIChat, + mistralAIToOpenAI, + mistralTextToMistralChat, passthroughToOpenAI, StreamingCompletionTransformer, + MistralChatCompletionEvent, } from "./index"; type SSEMessageTransformerOptions = TransformOptions & { @@ -35,7 +38,9 @@ export class SSEMessageTransformer extends Transform { private readonly inputFormat: APIFormat; private readonly transformFn: StreamingCompletionTransformer< // TODO: Refactor transformers to not assume only OpenAI events as output - OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent + | OpenAIChatCompletionStreamEvent + | AnthropicV2StreamEvent + | MistralChatCompletionEvent >; private readonly log; private readonly fallbackId: string; @@ -121,16 +126,17 @@ function eventIsOpenAIEvent( function getTransformer( responseApi: APIFormat, version?: string, - // There's only one case where we're not transforming back to OpenAI, which is - // Anthropic Chat response -> Anthropic Text request. This parameter is only - // used for that case. + // In most cases, we are transforming back to OpenAI. Some responses can be + // translated between two non-OpenAI formats, eg Anthropic Chat -> Anthropic + // Text, or Mistral Text -> Mistral Chat. requestApi: APIFormat = "openai" ): StreamingCompletionTransformer< - OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent + | OpenAIChatCompletionStreamEvent + | AnthropicV2StreamEvent + | MistralChatCompletionEvent > { switch (responseApi) { case "openai": - case "mistral-ai": return passthroughToOpenAI; case "openai-text": return openAITextToOpenAIChat; @@ -140,10 +146,16 @@ function getTransformer( : anthropicV2ToOpenAI; case "anthropic-chat": return requestApi === "anthropic-text" - ? anthropicChatToAnthropicV2 + ? anthropicChatToAnthropicV2 // User's legacy text prompt was converted to chat, and response must be converted back to text : anthropicChatToOpenAI; case "google-ai": return googleAIToOpenAI; + case "mistral-ai": + return mistralAIToOpenAI; + case "mistral-text": + return requestApi === "mistral-ai" + ? 
mistralTextToMistralChat // User's chat request was converted to text, and response must be converted back to chat + : mistralAIToOpenAI; case "openai-image": throw new Error(`SSE transformation not supported for ${responseApi}`); default: diff --git a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts index 7c54e57..f74bb9c 100644 --- a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts +++ b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts @@ -55,8 +55,10 @@ if ("completion" in eventObj) { return ["event: completion", `data: ${event}`].join(`\n`); - } else { + } else if (eventObj.type) { return [`event: ${eventObj.type}`, `data: ${event}`].join(`\n`); + } else { + return `data: ${event}`; } } // noinspection FallThroughInSwitchStatementJS -- non-JSON data is unexpected diff --git a/src/proxy/middleware/response/streaming/transformers/mistral-ai-to-openai.ts b/src/proxy/middleware/response/streaming/transformers/mistral-ai-to-openai.ts new file mode 100644 index 0000000..df34fba --- /dev/null +++ b/src/proxy/middleware/response/streaming/transformers/mistral-ai-to-openai.ts @@ -0,0 +1,76 @@ +import { logger } from "../../../../../logger"; +import { MistralAIStreamEvent, SSEResponseTransformArgs } from "../index"; +import { parseEvent, ServerSentEvent } from "../parse-sse"; + +const log = logger.child({ + module: "sse-transformer", + transformer: "mistral-ai-to-openai", +}); + +export const mistralAIToOpenAI = (params: SSEResponseTransformArgs) => { + const { data } = params; + + const rawEvent = parseEvent(data); + if (!rawEvent.data || rawEvent.data === "[DONE]") { + return { position: -1 }; + } + + const completionEvent = asCompletion(rawEvent); + if (!completionEvent) { + return { position: -1 }; + } + + if ("choices" in completionEvent) { + const newChatEvent = { + id: params.fallbackId, + object: "chat.completion.chunk" as const, + created: Date.now(), + model: params.fallbackModel, + choices: [ + { + index: completionEvent.choices[0].index, + delta: { content: completionEvent.choices[0].message.content }, + finish_reason: completionEvent.choices[0].stop_reason, + }, + ], + }; + return { position: -1, event: newChatEvent }; + } else if ("outputs" in completionEvent) { + const newTextEvent = { + id: params.fallbackId, + object: "chat.completion.chunk" as const, + created: Date.now(), + model: params.fallbackModel, + choices: [ + { + index: 0, + delta: { content: completionEvent.outputs[0].text }, + finish_reason: completionEvent.outputs[0].stop_reason, + }, + ], + }; + return { position: -1, event: newTextEvent }; + } + + // should never happen + return { position: -1 }; +}; + +function asCompletion(event: ServerSentEvent): MistralAIStreamEvent | null { + try { + const parsed = JSON.parse(event.data); + if ( + (Array.isArray(parsed.choices) && + parsed.choices[0].message !== undefined) || + (Array.isArray(parsed.outputs) && parsed.outputs[0].text !== undefined) + ) { + return parsed; + } else { + // noinspection ExceptionCaughtLocallyJS + throw new Error("Missing required fields"); + } + } catch (error: any) { + log.warn({ error: error.stack, event }, "Received invalid data event"); + } + return null; +} diff --git a/src/proxy/middleware/response/streaming/transformers/mistral-text-to-mistral-chat.ts b/src/proxy/middleware/response/streaming/transformers/mistral-text-to-mistral-chat.ts new file mode 100644 index 0000000..f15dd0f --- /dev/null
+++ b/src/proxy/middleware/response/streaming/transformers/mistral-text-to-mistral-chat.ts @@ -0,0 +1,63 @@ +import { + MistralChatCompletionEvent, + MistralTextCompletionEvent, + StreamingCompletionTransformer, +} from "../index"; +import { parseEvent, ServerSentEvent } from "../parse-sse"; +import { logger } from "../../../../../logger"; + +const log = logger.child({ + module: "sse-transformer", + transformer: "mistral-text-to-mistral-chat", +}); + +/** + * Transforms an incoming Mistral Text SSE to an equivalent Mistral Chat SSE. + * This is generally used when a client sends a Mistral Chat prompt, but we + * convert it to Mistral Text before sending it to the API to work around + * some bugs in Mistral/AWS prompt templating. In these cases we need to convert + * the response back to Mistral Chat. + */ +export const mistralTextToMistralChat: StreamingCompletionTransformer< + MistralChatCompletionEvent +> = (params) => { + const { data } = params; + + const rawEvent = parseEvent(data); + if (!rawEvent.data) { + return { position: -1 }; + } + + const textCompletion = asTextCompletion(rawEvent); + if (!textCompletion) { + return { position: -1 }; + } + + const chatEvent: MistralChatCompletionEvent = { + choices: [ + { + index: 0, + message: { role: "assistant", content: textCompletion.outputs[0].text }, + stop_reason: textCompletion.outputs[0].stop_reason, + }, + ], + }; + return { position: -1, event: chatEvent }; +}; + +function asTextCompletion( + event: ServerSentEvent +): MistralTextCompletionEvent | null { + try { + const parsed = JSON.parse(event.data); + if (Array.isArray(parsed.outputs) && parsed.outputs[0].text !== undefined) { + return parsed; + } else { + // noinspection ExceptionCaughtLocallyJS + throw new Error("Missing required fields"); + } + } catch (error: any) { + log.warn({ error: error.stack, event }, "Received invalid data event"); + } + return null; +} diff --git a/src/proxy/mistral-ai.ts b/src/proxy/mistral-ai.ts index 4740a11..db477ee 100644 --- a/src/proxy/mistral-ai.ts +++ b/src/proxy/mistral-ai.ts @@ -1,4 +1,4 @@ -import { RequestHandler, Router } from "express"; +import express, { Request, RequestHandler, Router } from "express"; import { createProxyMiddleware } from "http-proxy-middleware"; import { config } from "../config"; import { keyPool } from "../shared/key-management"; @@ -61,7 +61,7 @@ export const KNOWN_MISTRAL_AI_MODELS = [ "mistral-medium-latest", "mistral-medium-2312", "mistral-tiny", - "mistral-tiny-2312" + "mistral-tiny-2312", ]; let modelsCache: any = null; @@ -108,9 +108,24 @@ const mistralAIResponseHandler: ProxyResHandlerWithBody = async ( throw new Error("Expected body to be an object"); } - res.status(200).json({ ...body, proxy: body.proxy }); + let newBody = body; + if (req.inboundApi === "mistral-ai" && req.outboundApi === "mistral-text") { + newBody = transformMistralTextToMistralChat(body); + } + + res.status(200).json({ ...newBody, proxy: body.proxy }); }; +export function transformMistralTextToMistralChat(textBody: any) { + return { + ...textBody, + choices: [ + { message: { content: textBody.outputs[0].text, role: "assistant" } }, + ], + outputs: undefined, + }; +} + const mistralAIProxy = createQueueMiddleware({ proxyMiddleware: createProxyMiddleware({ target: "https://api.mistral.ai", @@ -133,12 +148,32 @@ mistralAIRouter.get("/v1/models", handleModelRequest); mistralAIRouter.post( "/v1/chat/completions", ipLimiter, - createPreprocessorMiddleware({ - inApi: "mistral-ai", - outApi: "mistral-ai", - service: "mistral-ai", - }), +
createPreprocessorMiddleware( + { + inApi: "mistral-ai", + outApi: "mistral-ai", + service: "mistral-ai", + }, + { beforeTransform: [detectMistralInputApi] } + ), mistralAIProxy ); +/** + * We can't determine if a request is Mistral text or chat just from the path + * because they both use the same endpoint. We need to check the request body + * for either `messages` or `prompt`. + * @param req + */ +export function detectMistralInputApi(req: Request) { + const { messages, prompt } = req.body; + if (messages) { + req.inboundApi = "mistral-ai"; + req.outboundApi = "mistral-ai"; + } else if (prompt) { + req.inboundApi = "mistral-text"; + req.outboundApi = "mistral-text"; + } +} + export const mistralAI = mistralAIRouter; diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index 9932ef2..069f0ab 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -1,44 +1,55 @@ -import express, { Request, Response, NextFunction } from "express"; -import { gatekeeper } from "./gatekeeper"; -import { checkRisuToken } from "./check-risu-token"; -import { openai } from "./openai"; -import { openaiImage } from "./openai-image"; +import express from "express"; +import { addV1 } from "./add-v1"; import { anthropic } from "./anthropic"; +import { aws } from "./aws"; +import { azure } from "./azure"; +import { checkRisuToken } from "./check-risu-token"; +import { gatekeeper } from "./gatekeeper"; +import { gcp } from "./gcp"; import { googleAI } from "./google-ai"; import { mistralAI } from "./mistral-ai"; -import { aws } from "./aws"; -import { gcp } from "./gcp"; -import { azure } from "./azure"; +import { openai } from "./openai"; +import { openaiImage } from "./openai-image"; import { sendErrorToClient } from "./middleware/response/error-generator"; const proxyRouter = express.Router(); + +// Remove `expect: 100-continue` header from requests due to incompatibility +// with node-http-proxy. proxyRouter.use((req, _res, next) => { if (req.headers.expect) { - // node-http-proxy does not like it when clients send `expect: 100-continue` - // and will stall. none of the upstream APIs use this header anyway. delete req.headers.expect; } next(); }); + +// Apply body parsers. proxyRouter.use( express.json({ limit: "100mb" }), express.urlencoded({ extended: true, limit: "100mb" }) ); + +// Apply auth/rate limits. proxyRouter.use(gatekeeper); proxyRouter.use(checkRisuToken); + +// Initialize request queue metadata. proxyRouter.use((req, _res, next) => { req.startTime = Date.now(); req.retryCount = 0; next(); }); + +// Proxy endpoints. proxyRouter.use("/openai", addV1, openai); proxyRouter.use("/openai-image", addV1, openaiImage); proxyRouter.use("/anthropic", addV1, anthropic); proxyRouter.use("/google-ai", addV1, googleAI); proxyRouter.use("/mistral-ai", addV1, mistralAI); -proxyRouter.use("/aws/claude", addV1, aws); +proxyRouter.use("/aws", aws); proxyRouter.use("/gcp/claude", addV1, gcp); proxyRouter.use("/azure/openai", addV1, azure); + // Redirect browser requests to the homepage. proxyRouter.get("*", (req, res, next) => { const isBrowser = req.headers["user-agent"]?.includes("Mozilla"); @@ -48,7 +59,8 @@ proxyRouter.get("*", (req, res, next) => { next(); } }); -// Handle 404s. + +// Send a fake client error if the user specifies an invalid proxy endpoint.
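+// As an illustration of the body-sniffing in detectMistralInputApi above
+// (hypothetical request bodies):
+//   { model: "mistral-large-latest", messages: [{ role: "user", content: "Hi" }] }
+//     -> inboundApi/outboundApi "mistral-ai" (chat)
+//   { model: "mistral-large-latest", prompt: "[INST] Hi [/INST]" }
+//     -> inboundApi/outboundApi "mistral-text" (raw prompt)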
proxyRouter.use((req, res) => { sendErrorToClient({ req, @@ -69,11 +81,3 @@ }); export { proxyRouter as proxyRouter }; - -function addV1(req: Request, res: Response, next: NextFunction) { - // Clients don't consistently use the /v1 prefix so we'll add it for them. - if (!req.path.startsWith("/v1/") && !req.path.startsWith("/v1beta/")) { - req.url = `/v1${req.url}`; - } - next(); -} diff --git a/src/service-info.ts b/src/service-info.ts index 5f3752c..5996e99 100644 --- a/src/service-info.ts +++ b/src/service-info.ts @@ -3,8 +3,6 @@ import { AnthropicKey, AwsBedrockKey, GcpKey, - AzureOpenAIKey, - GoogleAIKey, keyPool, OpenAIKey, } from "./shared/key-management"; @@ -26,21 +24,14 @@ import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats"; import { getUniqueIps } from "./proxy/rate-limit"; import { assertNever } from "./shared/utils"; import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue"; -import { MistralAIKey } from "./shared/key-management/mistral-ai/provider"; const CACHE_TTL = 2000; type KeyPoolKey = ReturnType<typeof keyPool.list>[0]; const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey => k.service === "openai"; -const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey => - k.service === "azure"; const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey => k.service === "anthropic"; -const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey => - k.service === "google-ai"; -const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey => - k.service === "mistral-ai"; const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws"; const keyIsGcpKey = (k: KeyPoolKey): k is GcpKey => k.service === "gcp"; @@ -54,14 +45,15 @@ type ModelAggregates = { overQuota?: number; pozzed?: number; awsLogged?: number; - awsSonnet?: number; - awsSonnet35?: number; - awsHaiku?: number; + // needed to disambiguate aws-claude family's variants + awsClaude2?: number; + awsSonnet3?: number; + awsSonnet3_5?: number; + awsHaiku?: number; gcpSonnet?: number; gcpSonnet35?: number; gcpHaiku?: number; queued: number; - queueTime: string; tokens: number; }; /** All possible combinations of model family and aggregate type.
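 * (e.g. `aws-claude__awsLogged` or `turbo__tokens`; these strings key the familyStats map below)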
*/ @@ -93,14 +85,10 @@ type AnthropicInfo = BaseFamilyInfo & { }; type AwsInfo = BaseFamilyInfo & { privacy?: string; - sonnetKeys?: number; - sonnet35Keys?: number; - haikuKeys?: number; + enabledVariants?: string; }; type GcpInfo = BaseFamilyInfo & { - sonnetKeys?: number; - sonnet35Keys?: number; - haikuKeys?: number; + enabledVariants?: string; }; // prettier-ignore @@ -108,12 +96,10 @@ export type ServiceInfo = { uptime: number; endpoints: { openai?: string; - openai2?: string; anthropic?: string; - "anthropic-claude-3"?: string; "google-ai"?: string; "mistral-ai"?: string; - aws?: string; + "aws"?: string; gcp?: string; azure?: string; "openai-image"?: string; @@ -151,7 +137,6 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = { openai: { openai: `%BASE%/openai`, - openai2: `%BASE%/openai/turbo-instruct`, "openai-image": `%BASE%/openai-image`, }, anthropic: { @@ -164,7 +149,8 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = { "mistral-ai": `%BASE%/mistral-ai`, }, aws: { - aws: `%BASE%/aws/claude`, + "aws-claude": `%BASE%/aws/claude`, + "aws-mistral": `%BASE%/aws/mistral`, }, gcp: { gcp: `%BASE%/gcp/claude`, @@ -175,7 +161,7 @@ }, }; -const modelStats = new Map<string, number>(); +const familyStats = new Map<string, number>(); const serviceStats = new Map<string, number>(); let cachedInfo: ServiceInfo | undefined; @@ -192,7 +178,7 @@ export function buildInfo(baseUrl: string, forAdmin = false): ServiceInfo { .concat("turbo") ); - modelStats.clear(); + familyStats.clear(); serviceStats.clear(); keys.forEach(addKeyToAggregates); @@ -311,150 +297,102 @@ function increment( ) { map.set(key, (map.get(key) || 0) + delta); } +const addToService = increment.bind(null, serviceStats); +const addToFamily = increment.bind(null, familyStats); function addKeyToAggregates(k: KeyPoolKey) { - increment(serviceStats, "proompts", k.promptCount); - increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0); - increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0); - increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0); - increment( - serviceStats, - "mistral-ai__keys", - k.service === "mistral-ai" ? 1 : 0 - ); - increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0); - increment(serviceStats, "gcp__keys", k.service === "gcp" ? 1 : 0); - increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0); + addToService("proompts", k.promptCount); + addToService("openai__keys", k.service === "openai" ? 1 : 0); + addToService("anthropic__keys", k.service === "anthropic" ? 1 : 0); + addToService("google-ai__keys", k.service === "google-ai" ? 1 : 0); + addToService("mistral-ai__keys", k.service === "mistral-ai" ? 1 : 0); + addToService("aws__keys", k.service === "aws" ? 1 : 0); + addToService("gcp__keys", k.service === "gcp" ? 1 : 0); + addToService("azure__keys", k.service === "azure" ? 1 : 0); let sumTokens = 0; let sumCost = 0; + const incrementGenericFamilyStats = (f: ModelFamily) => { + const tokens = (k as any)[`${f}Tokens`]; + sumTokens += tokens; + sumCost += getTokenCostUsd(f, tokens); + addToFamily(`${f}__tokens`, tokens); + addToFamily(`${f}__revoked`, k.isRevoked ? 1 : 0); + addToFamily(`${f}__active`, k.isDisabled ? 0 : 1); + }; + switch (k.service) { case "openai": if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type"); - increment( - serviceStats, - "openai__uncheckedKeys", - Boolean(k.lastChecked) ?
0 : 1 - ); - + addToService("openai__uncheckedKeys", Boolean(k.lastChecked) ? 0 : 1); k.modelFamilies.forEach((f) => { - const tokens = k[`${f}Tokens`]; - sumTokens += tokens; - sumCost += getTokenCostUsd(f, tokens); - increment(modelStats, `${f}__tokens`, tokens); - increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0); - increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1); - increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0); - increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0); + incrementGenericFamilyStats(f); + addToFamily(`${f}__trial`, k.isTrial ? 1 : 0); + addToFamily(`${f}__overQuota`, k.isOverQuota ? 1 : 0); }); break; - case "azure": - if (!keyIsAzureKey(k)) throw new Error("Invalid key type"); - k.modelFamilies.forEach((f) => { - const tokens = k[`${f}Tokens`]; - sumTokens += tokens; - sumCost += getTokenCostUsd(f, tokens); - increment(modelStats, `${f}__tokens`, tokens); - increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1); - increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0); - }); - break; - case "anthropic": { + case "anthropic": if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type"); + addToService("anthropic__uncheckedKeys", Boolean(k.lastChecked) ? 0 : 1); k.modelFamilies.forEach((f) => { - const tokens = k[`${f}Tokens`]; - sumTokens += tokens; - sumCost += getTokenCostUsd(f, tokens); - increment(modelStats, `${f}__tokens`, tokens); - increment(modelStats, `${f}__trial`, k.tier === "free" ? 1 : 0); - increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0); - increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1); - increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0); - increment(modelStats, `${f}__pozzed`, k.isPozzed ? 1 : 0); - }); - increment( - serviceStats, - "anthropic__uncheckedKeys", - Boolean(k.lastChecked) ? 0 : 1 - ); - break; - } - case "google-ai": { - if (!keyIsGoogleAIKey(k)) throw new Error("Invalid key type"); - k.modelFamilies.forEach((family) => { - const tokens = k[`${family}Tokens`]; - sumTokens += tokens; - sumCost += getTokenCostUsd(family, tokens); - increment(modelStats, `${family}__tokens`, tokens); - increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1); - increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0); + incrementGenericFamilyStats(f); + addToFamily(`${f}__trial`, k.tier === "free" ? 1 : 0); + addToFamily(`${f}__overQuota`, k.isOverQuota ? 1 : 0); + addToFamily(`${f}__pozzed`, k.isPozzed ? 1 : 0); }); break; - } - case "mistral-ai": { - if (!keyIsMistralAIKey(k)) throw new Error("Invalid key type"); - k.modelFamilies.forEach((f) => { - const tokens = k[`${f}Tokens`]; - sumTokens += tokens; - sumCost += getTokenCostUsd(f, tokens); - increment(modelStats, `${f}__tokens`, tokens); - increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0); - increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1); - }); - break; - } + case "aws": { if (!keyIsAwsKey(k)) throw new Error("Invalid key type"); - k.modelFamilies.forEach((f) => { - const tokens = k[`${f}Tokens`]; - sumTokens += tokens; - sumCost += getTokenCostUsd(f, tokens); - increment(modelStats, `${f}__tokens`, tokens); - increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0); - increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1); - }); - increment(modelStats, `aws-claude__awsSonnet`, k.sonnetEnabled ? 1 : 0); - increment(modelStats, `aws-claude__awsSonnet35`, k.sonnet35Enabled ? 1 : 0); - increment(modelStats, `aws-claude__awsHaiku`, k.haikuEnabled ? 
1 : 0); - + k.modelFamilies.forEach(incrementGenericFamilyStats); + if (!k.isDisabled) { + // Don't add revoked keys to available AWS variants + k.modelIds.forEach((id) => { + if (id.includes("claude-3-sonnet")) { + addToFamily(`aws-claude__awsSonnet3`, 1); + } else if (id.includes("claude-3-5-sonnet")) { + addToFamily(`aws-claude__awsSonnet3_5`, 1); + } else if (id.includes("claude-3-haiku")) { + addToFamily(`aws-claude__awsHaiku`, 1); + } else if (id.includes("claude-v2")) { + addToFamily(`aws-claude__awsClaude2`, 1); + } + }); + } // Ignore revoked keys for aws logging stats, but include keys where the // logging status is unknown. const countAsLogged = k.lastChecked && !k.isDisabled && k.awsLoggingStatus === "enabled"; - increment(modelStats, `aws-claude__awsLogged`, countAsLogged ? 1 : 0); + addToFamily(`aws-claude__awsLogged`, countAsLogged ? 1 : 0); break; } - case "gcp": { + case "gcp": if (!keyIsGcpKey(k)) throw new Error("Invalid key type"); - k.modelFamilies.forEach((f) => { - const tokens = k[`${f}Tokens`]; - sumTokens += tokens; - sumCost += getTokenCostUsd(f, tokens); - increment(modelStats, `${f}__tokens`, tokens); - increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0); - increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1); - }); - increment(modelStats, `gcp-claude__gcpSonnet`, k.sonnetEnabled ? 1 : 0); - increment(modelStats, `gcp-claude__gcpSonnet35`, k.sonnet35Enabled ? 1 : 0); - increment(modelStats, `gcp-claude__gcpHaiku`, k.haikuEnabled ? 1 : 0); + k.modelFamilies.forEach(incrementGenericFamilyStats); + // TODO: add modelIds to GcpKey + break; + // These services don't have any additional stats to track. + case "azure": + case "google-ai": + case "mistral-ai": + k.modelFamilies.forEach(incrementGenericFamilyStats); break; - } default: assertNever(k.service); } - increment(serviceStats, "tokens", sumTokens); - increment(serviceStats, "tokenCost", sumCost); + addToService("tokens", sumTokens); + addToService("tokenCost", sumCost); } function getInfoForFamily(family: ModelFamily): BaseFamilyInfo { - const tokens = modelStats.get(`${family}__tokens`) || 0; + const tokens = familyStats.get(`${family}__tokens`) || 0; const cost = getTokenCostUsd(family, tokens); let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = { usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`, - activeKeys: modelStats.get(`${family}__active`) || 0, - revokedKeys: modelStats.get(`${family}__revoked`) || 0, + activeKeys: familyStats.get(`${family}__active`) || 0, + revokedKeys: familyStats.get(`${family}__revoked`) || 0, }; // Add service-specific stats to the info object. @@ -462,8 +400,8 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo { const service = MODEL_FAMILY_SERVICE[family]; switch (service) { case "openai": - info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0; - info.trialKeys = modelStats.get(`${family}__trial`) || 0; + info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0; + info.trialKeys = familyStats.get(`${family}__trial`) || 0; // Delete trial/revoked keys for non-turbo families. 
// Trials are turbo 99% of the time, and if a key is invalid we don't @@ -474,16 +412,25 @@ } break; case "anthropic": - info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0; - info.trialKeys = modelStats.get(`${family}__trial`) || 0; - info.prefilledKeys = modelStats.get(`${family}__pozzed`) || 0; + info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0; + info.trialKeys = familyStats.get(`${family}__trial`) || 0; + info.prefilledKeys = familyStats.get(`${family}__pozzed`) || 0; break; case "aws": if (family === "aws-claude") { - info.sonnetKeys = modelStats.get(`${family}__awsSonnet`) || 0; - info.sonnet35Keys = modelStats.get(`${family}__awsSonnet35`) || 0; - info.haikuKeys = modelStats.get(`${family}__awsHaiku`) || 0; - const logged = modelStats.get(`${family}__awsLogged`) || 0; + const logged = familyStats.get(`${family}__awsLogged`) || 0; + const variants = new Set<string>(); + if (familyStats.get(`${family}__awsClaude2`) || 0) + variants.add("claude2"); + if (familyStats.get(`${family}__awsSonnet3`) || 0) + variants.add("sonnet3"); + if (familyStats.get(`${family}__awsSonnet3_5`) || 0) + variants.add("sonnet3.5"); + if (familyStats.get(`${family}__awsHaiku`) || 0) + variants.add("haiku"); + info.enabledVariants = variants.size + ? `${Array.from(variants).join(",")}` + : undefined; if (logged > 0) { info.privacy = config.allowAwsLogging ? `AWS logging verification inactive. Prompts could be logged.` @@ -493,9 +440,8 @@ break; case "gcp": if (family === "gcp-claude") { - info.sonnetKeys = modelStats.get(`${family}__gcpSonnet`) || 0; - info.sonnet35Keys = modelStats.get(`${family}__gcpSonnet35`) || 0; - info.haikuKeys = modelStats.get(`${family}__gcpHaiku`) || 0; + // TODO: implement + info.enabledVariants = "not implemented"; } break; } diff --git a/src/shared/api-schemas/index.ts b/src/shared/api-schemas/index.ts index 598bf23..8ccd5de 100644 --- a/src/shared/api-schemas/index.ts +++ b/src/shared/api-schemas/index.ts @@ -21,7 +21,11 @@ import { GoogleAIV1GenerateContentSchema, transformOpenAIToGoogleAI, } from "./google-ai"; -import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai"; +import { + MistralAIV1ChatCompletionsSchema, + MistralAIV1TextCompletionsSchema, + transformMistralChatToText, +} from "./mistral-ai"; export { OpenAIChatMessage } from "./openai"; export { @@ -49,6 +53,7 @@ export const API_REQUEST_TRANSFORMERS: TransformerMap = { "openai->openai-text": transformOpenAIToOpenAIText, "openai->openai-image": transformOpenAIToOpenAIImage, "openai->google-ai": transformOpenAIToGoogleAI, + "mistral-ai->mistral-text": transformMistralChatToText, }; export const API_REQUEST_VALIDATORS: Record<APIFormat, ZodType<any>> = { @@ -59,4 +64,5 @@ "openai-image": OpenAIV1ImagesGenerationSchema, "google-ai": GoogleAIV1GenerateContentSchema, "mistral-ai": MistralAIV1ChatCompletionsSchema, + "mistral-text": MistralAIV1TextCompletionsSchema, }; diff --git a/src/shared/api-schemas/mistral-ai.ts b/src/shared/api-schemas/mistral-ai.ts index d67f246..a1e4956 100644 --- a/src/shared/api-schemas/mistral-ai.ts +++ b/src/shared/api-schemas/mistral-ai.ts @@ -1,15 +1,35 @@ import { z } from "zod"; import { OPENAI_OUTPUT_MAX } from "./openai"; +import { Template } from "@huggingface/jinja"; +import { APIFormatTransformer } from "./index"; +import { logger } from "../../logger"; + +const MistralChatMessageSchema = z.object({
role: z.enum(["system", "user", "assistant", "tool"]), // TODO: implement tools + content: z.string(), + prefix: z.boolean().optional(), +}); + +const MistralMessagesSchema = z.array(MistralChatMessageSchema).refine( + (input) => { + const prefixIdx = input.findIndex((msg) => Boolean(msg.prefix)); + if (prefixIdx === -1) return true; // no prefix messages + const lastIdx = input.length - 1; + const lastMsg = input[lastIdx]; + return prefixIdx === lastIdx && lastMsg.role === "assistant"; + }, + { + message: + "`prefix` can only be set to `true` on the last message, and only for an assistant message.", + } +); // https://docs.mistral.ai/api#operation/createChatCompletion -export const MistralAIV1ChatCompletionsSchema = z.object({ +const BaseMistralAIV1CompletionsSchema = z.object({ model: z.string(), - messages: z.array( - z.object({ - role: z.enum(["system", "user", "assistant"]), - content: z.string(), - }) - ), + // One must be provided, checked in a refinement + messages: MistralMessagesSchema.optional(), + prompt: z.string().optional(), temperature: z.number().optional().default(0.7), top_p: z.number().optional().default(1), max_tokens: z.coerce @@ -18,12 +38,48 @@ export const MistralAIV1ChatCompletionsSchema = z.object({ .nullish() .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)), stream: z.boolean().optional().default(false), + // Mistral docs say that `stop` can be a string or array but AWS Mistral + // blows up if a string is passed. We must convert it to an array. + stop: z + .union([z.string(), z.array(z.string())]) + .optional() + .default([]) + .transform((v) => (Array.isArray(v) ? v : [v])), + random_seed: z.number().int().min(0).optional(), + response_format: z.enum(["text", "json_object"]).optional().default("text"), safe_prompt: z.boolean().optional().default(false), - random_seed: z.number().int().optional(), }); -export type MistralAIChatMessage = z.infer< - typeof MistralAIV1ChatCompletionsSchema ->["messages"][0]; + +export const MistralAIV1ChatCompletionsSchema = + BaseMistralAIV1CompletionsSchema.and( + z.object({ messages: MistralMessagesSchema }) + ); +export const MistralAIV1TextCompletionsSchema = + BaseMistralAIV1CompletionsSchema.and(z.object({ prompt: z.string() })); + +/* + Slightly more strict version that only allows a subset of the parameters. AWS + Mistral helpfully returns no details if unsupported parameters are passed so + this list comes from trial and error as of 2024-08-12. +*/ +const BaseAWSMistralAIV1CompletionsSchema = + BaseMistralAIV1CompletionsSchema.pick({ + temperature: true, + top_p: true, + max_tokens: true, + stop: true, + random_seed: true, + // response_format: true, + // safe_prompt: true, + }).strip(); +export const AWSMistralV1ChatCompletionsSchema = + BaseAWSMistralAIV1CompletionsSchema.and( + z.object({ messages: MistralMessagesSchema }) + ); +export const AWSMistralV1TextCompletionsSchema = + BaseAWSMistralAIV1CompletionsSchema.and(z.object({ prompt: z.string() })); + +export type MistralAIChatMessage = z.infer; export function fixMistralPrompt( messages: MistralAIChatMessage[] @@ -31,12 +87,11 @@ export function fixMistralPrompt( // Mistral uses OpenAI format but has some additional requirements: // - Only one system message per request, and it must be the first message if // present. - // - Final message must be a user message. + // - Final message must be a user message, unless it has `prefix: true`. // - Cannot have multiple messages from the same role in a row. 
// While frontends should be able to handle this, we can fix it here in the // meantime. - - return messages.reduce<MistralAIChatMessage[]>((acc, msg) => { + const fixed = messages.reduce<MistralAIChatMessage[]>((acc, msg) => { if (acc.length === 0) { acc.push(msg); return acc; } @@ -57,4 +112,54 @@ } return acc; }, []); + + // If the last message is an assistant message, mark it as a prefix. An + // assistant message at the end of the conversation without `prefix: true` + // results in an error. + if (fixed[fixed.length - 1].role === "assistant") { + fixed[fixed.length - 1].prefix = true; + } + return fixed; } + +let jinjaTemplate: Template; +let renderTemplate: (messages: MistralAIChatMessage[]) => string; +function renderMistralPrompt(messages: MistralAIChatMessage[]) { + if (!jinjaTemplate) { + logger.warn("Lazy loading mistral chat template..."); + const { chatTemplate, bosToken, eosToken } = + require("./templates/mistral-template").MISTRAL_TEMPLATE; + jinjaTemplate = new Template(chatTemplate); + renderTemplate = (messages) => + jinjaTemplate.render({ + messages, + bos_token: bosToken, + eos_token: eosToken, + }); + } + + return renderTemplate(messages); +} + +/** + * Attempts to convert a Mistral chat completions request to a text completions, + * using the official prompt template published by Mistral. + */ +export const transformMistralChatToText: APIFormatTransformer< + typeof MistralAIV1TextCompletionsSchema +> = async (req) => { + const { body } = req; + const result = MistralAIV1ChatCompletionsSchema.safeParse(body); + if (!result.success) { + req.log.warn( + { issues: result.error.issues, body }, + "Invalid Mistral chat completions request" + ); + throw result.error; + } + + const { messages, ...rest } = result.data; + const prompt = renderMistralPrompt(messages); + + return { ...rest, prompt, messages: undefined }; +}; diff --git a/src/shared/api-schemas/templates/mistral-template.ts b/src/shared/api-schemas/templates/mistral-template.ts new file mode 100644 index 0000000..bb3de1f --- /dev/null +++ b/src/shared/api-schemas/templates/mistral-template.ts @@ -0,0 +1,36 @@ +export const MISTRAL_TEMPLATE = { + bosToken: "<s>", + eosToken: "</s>", + chatTemplate: `{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "assistant" %} + {%- if loop.last and message.prefix is defined and message.prefix %} + {{- " " + message["content"] }} + {%- else %} + {{- " " + message["content"] + eos_token}} + {%- endif %} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %}`, +}; diff --git a/src/shared/key-management/anthropic/provider.ts
b/src/shared/key-management/anthropic/provider.ts index 057686e..bcd8dfb 100644 --- a/src/shared/key-management/anthropic/provider.ts +++ b/src/shared/key-management/anthropic/provider.ts @@ -1,5 +1,5 @@ import crypto from "crypto"; -import { Key, KeyProvider } from ".."; +import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; import { config } from "../../../config"; import { logger } from "../../../logger"; import { AnthropicModelFamily, getClaudeModelFamily } from "../../models"; @@ -23,10 +23,6 @@ type AnthropicKeyUsage = { export interface AnthropicKey extends Key, AnthropicKeyUsage { readonly service: "anthropic"; readonly modelFamilies: AnthropicModelFamily[]; - /** The time at which this key was last rate limited. */ - rateLimitedAt: number; - /** The time until which this key is rate limited. */ - rateLimitedUntil: number; /** * Whether this key requires a special preamble. For unclear reasons, some * Anthropic keys will throw an error if the prompt does not begin with a @@ -217,22 +213,7 @@ export class AnthropicKeyProvider implements KeyProvider { key[`${getClaudeModelFamily(model)}Tokens`] += tokens; } - public getLockoutPeriod() { - const activeKeys = this.keys.filter((k) => !k.isDisabled); - // Don't lock out if there are no keys available or the queue will stall. - // Just let it through so the add-key middleware can throw an error. - if (activeKeys.length === 0) return 0; - - const now = Date.now(); - const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil); - const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length; - - if (anyNotRateLimited) return 0; - - // If all keys are rate-limited, return the time until the first key is - // ready. - return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now)); - } + getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); /** * This is called when we receive a 429, which means there are already five diff --git a/src/shared/key-management/aws/checker.ts b/src/shared/key-management/aws/checker.ts index de22c64..62d57ed 100644 --- a/src/shared/key-management/aws/checker.ts +++ b/src/shared/key-management/aws/checker.ts @@ -5,9 +5,21 @@ import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios"; import { URL } from "url"; import { KeyCheckerBase } from "../key-checker-base"; import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider"; -import { AwsBedrockModelFamily } from "../../models"; +import { getAwsBedrockModelFamily } from "../../models"; import { config } from "../../../config"; +const KNOWN_MODEL_IDS = [ + "anthropic.claude-v2", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0", + "anthropic.claude-3-opus-20240229-v1:0", + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "mistral.mistral-7b-instruct-v0:2", + "mistral.mixtral-8x7b-instruct-v0:1", + "mistral.mistral-large-2402-v1:0", + "mistral.mistral-large-2407-v1:0", + "mistral.mistral-small-2402-v1:0", // Seems to return 400 +]; const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes const AMZ_HOST = @@ -47,36 +59,20 @@ export class AwsKeyChecker extends KeyCheckerBase { } protected async testKeyOrFail(key: AwsBedrockKey) { - // Only check models on startup. For now all models must be available to - // the proxy because we don't route requests to different keys. 
- let checks: Promise<unknown>[] = []; const isInitialCheck = !key.lastChecked; - if (isInitialCheck) { - checks = [ - this.invokeModel("anthropic.claude-v2", key), - this.invokeModel("anthropic.claude-3-sonnet-20240229-v1:0", key), - this.invokeModel("anthropic.claude-3-haiku-20240307-v1:0", key), - this.invokeModel("anthropic.claude-3-opus-20240229-v1:0", key), - this.invokeModel("anthropic.claude-3-5-sonnet-20240620-v1:0", key), - ]; - } - - checks.unshift(this.checkLoggingConfiguration(key)); - - const [_logging, claudeV2, sonnet, haiku, opus, sonnet35] = - await Promise.all(checks); - - this.log.debug( - { key: key.hash, _logging, claudeV2, sonnet, haiku, opus, sonnet35 }, - "AWS model tests complete." - ); if (isInitialCheck) { - const families: AwsBedrockModelFamily[] = []; - if (claudeV2 || sonnet || sonnet35 || haiku) families.push("aws-claude"); - if (opus) families.push("aws-claude-opus"); + const checks = await Promise.all( + KNOWN_MODEL_IDS.map(async (model) => { + const success = await this.invokeModel(model, key); + return { model, success }; + }) + ); + const modelIds = checks + .filter(({ success }) => success) + .map(({ model }) => model); - if (families.length === 0) { + if (modelIds.length === 0) { this.log.warn( { key: key.hash }, "Key does not have access to any models; disabling." @@ -85,20 +81,19 @@ } this.updateKey(key.hash, { - sonnetEnabled: sonnet, - haikuEnabled: haiku, - sonnet35Enabled: sonnet35, - modelFamilies: families, + modelIds, + modelFamilies: Array.from( + new Set(modelIds.map(getAwsBedrockModelFamily)) + ), }); } this.log.info( { key: key.hash, - sonnet, - haiku, - families: key.modelFamilies, logged: key.awsLoggingStatus, + families: key.modelFamilies, + models: key.modelIds, }, "Checked key." ); @@ -169,7 +164,19 @@ * key has access to the model, false if it does not. Throws an error if the * key is disabled. */ - private async invokeModel(model: string, key: AwsBedrockKey) { + private async invokeModel( + model: string, + key: AwsBedrockKey + ): Promise<boolean> { + if (model.includes("claude")) { + return this.testClaudeModel(key, model); + } else if (model.includes("mistral")) { + return this.testMistralModel(key, model); + } + throw new Error("AwsKeyChecker#invokeModel: no implementation for model"); + } + + private async testClaudeModel(key: AwsBedrockKey, model: string): Promise<boolean> { const creds = AwsKeyChecker.getCredentialsFromKey(key); // This is not a valid invocation payload, but a 400 response indicates that // the principal at least has permission to invoke the model. @@ -196,14 +203,15 @@ const errorType = (headers["x-amzn-errortype"] as string).split(":")[0]; const errorMessage = data?.message; - // We only allow one type of 403 error, and we only allow it for one model. + // This message indicates the key is valid but this particular model is not + // accessible. Other 403s may indicate the key is not usable. if ( status === 403 && errorMessage?.match(/access to the model with the specified model ID/) ) { return false; } - + // ResourceNotFound typically indicates that the tested model cannot be used + // on the configured region for this set of credentials.
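+ // To summarize the probe heuristics: a 404 means the region lacks the model;
+ // a 403 mentioning model access means the key works but the model is not
+ // enabled; only a 400 ValidationException complaining about max_tokens (the
+ // payload is deliberately invalid) counts as confirmed access. A passing
+ // probe response looks roughly like (hypothetical values):
+ //   status: 400, x-amzn-errortype: "ValidationException", message: "... max_tokens ..."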
if (status === 404) { @@ -219,6 +227,58 @@ const correctErrorType = errorType === "ValidationException"; const correctErrorMessage = errorMessage?.match(/max_tokens/); if (!correctErrorType || !correctErrorMessage) { + this.log.debug( + { key: key.hash, model, errorType, data, status }, + "AWS InvokeModel test unsuccessful." + ); return false; + } + + this.log.debug( + { key: key.hash, model, errorType, data, status }, + "AWS InvokeModel test successful." + ); + return true; + } + + private async testMistralModel(key: AwsBedrockKey, model: string): Promise<boolean> { + const creds = AwsKeyChecker.getCredentialsFromKey(key); + + const payload = { + max_tokens: -1, + prompt: "[INST] What is your favourite condiment? [/INST]", + }; + const config: AxiosRequestConfig = { + method: "POST", + url: POST_INVOKE_MODEL_URL(creds.region, model), + data: payload, + validateStatus: (status) => [400, 403, 404].includes(status), + headers: { + "content-type": "application/json", + accept: "*/*", + } + }; + await AwsKeyChecker.signRequestForAws(config, key); + const response = await axios.request(config); + const { data, status, headers } = response; + const errorType = (headers["x-amzn-errortype"] as string).split(":")[0]; + const errorMessage = data?.message; + + if (status === 403 || status === 404) { + this.log.debug( + { key: key.hash, model, errorType, data, status }, + "AWS InvokeModel test returned 403 or 404." + ); + return false; + } + + const isBadRequest = status === 400; + const isValidationError = errorMessage?.match(/validation error/i); + if (isBadRequest && !isValidationError) { + this.log.debug( + { key: key.hash, model, errorType, data, status, headers }, + "AWS InvokeModel test returned 400 but not a validation error." + ); return false; } diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts index fe23809..76d10b5 100644 --- a/src/shared/key-management/aws/provider.ts +++ b/src/shared/key-management/aws/provider.ts @@ -1,10 +1,11 @@ import crypto from "crypto"; -import { Key, KeyProvider } from ".."; import { config } from "../../../config"; import { logger } from "../../../logger"; -import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models"; -import { AwsKeyChecker } from "./checker"; import { PaymentRequiredError } from "../../errors"; +import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models"; +import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; +import { prioritizeKeys } from "../prioritize-keys"; +import { AwsKeyChecker } from "./checker"; type AwsBedrockKeyUsage = { [K in AwsBedrockModelFamily as `${K}Tokens`]: number; @@ -13,10 +14,6 @@ }; export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage { readonly service: "aws"; readonly modelFamilies: AwsBedrockModelFamily[]; - /** The time at which this key was last rate limited. */ - rateLimitedAt: number; - /** The time until which this key is rate limited. */ - rateLimitedUntil: number; /** * The confirmed logging status of this key. This is "unknown" until we * receive a response from the AWS API. Keys which are logged, or not * confirmed as not being logged, won't be used unless ALLOW_AWS_LOGGING is * set.
*/ awsLoggingStatus: "unknown" | "disabled" | "enabled"; - sonnetEnabled: boolean; - haikuEnabled: boolean; - sonnet35Enabled: boolean; + modelIds: string[]; } /** @@ -76,11 +71,13 @@ export class AwsBedrockKeyProvider implements KeyProvider { .digest("hex") .slice(0, 8)}`, lastChecked: 0, - sonnetEnabled: true, - haikuEnabled: false, - sonnet35Enabled: false, + modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"], ["aws-claudeTokens"]: 0, ["aws-claude-opusTokens"]: 0, + ["aws-mistral-tinyTokens"]: 0, + ["aws-mistral-smallTokens"]: 0, + ["aws-mistral-mediumTokens"]: 0, + ["aws-mistral-largeTokens"]: 0, }; this.keys.push(newKey); } @@ -99,41 +96,35 @@ export class AwsBedrockKeyProvider implements KeyProvider { } public get(model: string) { + let neededVariantId = model; + // This function accepts both Anthropic/Mistral IDs and AWS IDs. + // Generally all AWS model IDs are supersets of the original vendor IDs. + // Claude 2 is the only model that breaks this convention; Anthropic calls + // it claude-2 but AWS calls it claude-v2. + if (model.includes("claude-2")) neededVariantId = "claude-v2"; const neededFamily = getAwsBedrockModelFamily(model); - // this is a horrible mess - // each of these should be separate model families, but adding model - // families is not low enough friction for the rate at which aws claude - // model variants are added. - const needsSonnet35 = - model.includes("claude-3-5-sonnet") && neededFamily === "aws-claude"; - const needsSonnet = - !needsSonnet35 && - model.includes("sonnet") && - neededFamily === "aws-claude"; - const needsHaiku = model.includes("haiku") && neededFamily === "aws-claude"; - const availableKeys = this.keys.filter((k) => { - const isNotLogged = k.awsLoggingStatus !== "enabled"; + // Select keys which return ( + // are enabled !k.isDisabled && - (isNotLogged || config.allowAwsLogging) && - (k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under aws-claude, while opus is not - (k.haikuEnabled || !needsHaiku) && - (k.sonnet35Enabled || !needsSonnet35) && - k.modelFamilies.includes(neededFamily) + // are not logged, unless policy allows it + (config.allowAwsLogging || k.awsLoggingStatus !== "enabled") && + // have access to the model family we need + k.modelFamilies.includes(neededFamily) && + // have access to the specific variant we need + k.modelIds.some((m) => m.includes(neededVariantId)) ); }); this.log.debug( { - model, - neededFamily, - needsSonnet, - needsHaiku, - needsSonnet35, - availableKeys: availableKeys.length, + requestedModel: model, + selectedVariant: neededVariantId, + selectedFamily: neededFamily, totalKeys: this.keys.length, + availableKeys: availableKeys.length, }, "Selecting AWS key" ); @@ -144,30 +135,8 @@ export class AwsBedrockKeyProvider implements KeyProvider { ); } - // (largely copied from the OpenAI provider, without trial key support) - // Select a key, from highest priority to lowest priority: - // 1. Keys which are not rate limited - // a. If all keys were rate limited recently, select the least-recently - // rate limited key. - // 3. 
Keys which have not been used in the longest time - - const now = Date.now(); - - const keysByPriority = availableKeys.sort((a, b) => { - const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT; - const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT; - - if (aRateLimited && !bRateLimited) return 1; - if (!aRateLimited && bRateLimited) return -1; - if (aRateLimited && bRateLimited) { - return a.rateLimitedAt - b.rateLimitedAt; - } - - return a.lastUsed - b.lastUsed; - }); - - const selectedKey = keysByPriority[0]; - selectedKey.lastUsed = now; + const selectedKey = prioritizeKeys(availableKeys)[0]; + selectedKey.lastUsed = Date.now(); this.throttle(selectedKey.hash); return { ...selectedKey }; } @@ -195,22 +164,7 @@ export class AwsBedrockKeyProvider implements KeyProvider { key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens; } - public getLockoutPeriod() { - // TODO: same exact behavior for three providers, should be refactored - const activeKeys = this.keys.filter((k) => !k.isDisabled); - // Don't lock out if there are no keys available or the queue will stall. - // Just let it through so the add-key middleware can throw an error. - if (activeKeys.length === 0) return 0; - - const now = Date.now(); - const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil); - const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length; - - if (anyNotRateLimited) return 0; - - // If all keys are rate-limited, return time until the first key is ready. - return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now)); - } + getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); /** * This is called when we receive a 429, which means there are already five diff --git a/src/shared/key-management/azure/provider.ts b/src/shared/key-management/azure/provider.ts index 681a6ed..28439fa 100644 --- a/src/shared/key-management/azure/provider.ts +++ b/src/shared/key-management/azure/provider.ts @@ -1,10 +1,13 @@ import crypto from "crypto"; -import { Key, KeyProvider } from ".."; import { config } from "../../../config"; -import { PaymentRequiredError } from "../../errors"; import { logger } from "../../../logger"; -import type { AzureOpenAIModelFamily } from "../../models"; -import { getAzureOpenAIModelFamily } from "../../models"; +import { PaymentRequiredError } from "../../errors"; +import { + AzureOpenAIModelFamily, + getAzureOpenAIModelFamily, +} from "../../models"; +import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; +import { prioritizeKeys } from "../prioritize-keys"; import { AzureOpenAIKeyChecker } from "./checker"; type AzureOpenAIKeyUsage = { @@ -14,10 +17,6 @@ type AzureOpenAIKeyUsage = { export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage { readonly service: "azure"; readonly modelFamilies: AzureOpenAIModelFamily[]; - /** The time at which this key was last rate limited. */ - rateLimitedAt: number; - /** The time until which this key is rate limited. */ - rateLimitedUntil: number; contentFiltering: boolean; } @@ -105,30 +104,8 @@ export class AzureOpenAIKeyProvider implements KeyProvider { ); } - // (largely copied from the OpenAI provider, without trial key support) - // Select a key, from highest priority to lowest priority: - // 1. Keys which are not rate limited - // a. If all keys were rate limited recently, select the least-recently - // rate limited key. - // 3. 
Keys which have not been used in the longest time - - const now = Date.now(); - - const keysByPriority = availableKeys.sort((a, b) => { - const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT; - const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT; - - if (aRateLimited && !bRateLimited) return 1; - if (!aRateLimited && bRateLimited) return -1; - if (aRateLimited && bRateLimited) { - return a.rateLimitedAt - b.rateLimitedAt; - } - - return a.lastUsed - b.lastUsed; - }); - - const selectedKey = keysByPriority[0]; - selectedKey.lastUsed = now; + const selectedKey = prioritizeKeys(availableKeys)[0]; + selectedKey.lastUsed = Date.now(); this.throttle(selectedKey.hash); return { ...selectedKey }; } @@ -156,26 +133,7 @@ export class AzureOpenAIKeyProvider implements KeyProvider { key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens; } - // TODO: all of this shit is duplicate code - - public getLockoutPeriod(family: AzureOpenAIModelFamily) { - const activeKeys = this.keys.filter( - (key) => !key.isDisabled && key.modelFamilies.includes(family) - ); - - // Don't lock out if there are no keys available or the queue will stall. - // Just let it through so the add-key middleware can throw an error. - if (activeKeys.length === 0) return 0; - - const now = Date.now(); - const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil); - const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length; - - if (anyNotRateLimited) return 0; - - // If all keys are rate-limited, return time until the first key is ready. - return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now)); - } + getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); /** * This is called when we receive a 429, which means there are already five diff --git a/src/shared/key-management/gcp/provider.ts b/src/shared/key-management/gcp/provider.ts index 8e9c9ab..e3f72ef 100644 --- a/src/shared/key-management/gcp/provider.ts +++ b/src/shared/key-management/gcp/provider.ts @@ -1,10 +1,11 @@ import crypto from "crypto"; -import { Key, KeyProvider } from ".."; import { config } from "../../../config"; import { logger } from "../../../logger"; -import { GcpModelFamily, getGcpModelFamily } from "../../models"; -import { GcpKeyChecker } from "./checker"; import { PaymentRequiredError } from "../../errors"; +import { GcpModelFamily, getGcpModelFamily } from "../../models"; +import { createGenericGetLockoutPeriod, Key, KeyProvider } from ".."; +import { prioritizeKeys } from "../prioritize-keys"; +import { GcpKeyChecker } from "./checker"; type GcpKeyUsage = { [K in GcpModelFamily as `${K}Tokens`]: number; @@ -13,10 +14,6 @@ type GcpKeyUsage = { export interface GcpKey extends Key, GcpKeyUsage { readonly service: "gcp"; readonly modelFamilies: GcpModelFamily[]; - /** The time at which this key was last rate limited. */ - rateLimitedAt: number; - /** The time until which this key is rate limited. */ - rateLimitedUntil: number; sonnetEnabled: boolean; haikuEnabled: boolean; sonnet35Enabled: boolean; @@ -134,30 +131,8 @@ export class GcpKeyProvider implements KeyProvider { ); } - // (largely copied from the OpenAI provider, without trial key support) - // Select a key, from highest priority to lowest priority: - // 1. Keys which are not rate limited - // a. If all keys were rate limited recently, select the least-recently - // rate limited key. - // 3. 
-
-    const now = Date.now();
-
-    const keysByPriority = availableKeys.sort((a, b) => {
-      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
-      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
-
-      if (aRateLimited && !bRateLimited) return 1;
-      if (!aRateLimited && bRateLimited) return -1;
-      if (aRateLimited && bRateLimited) {
-        return a.rateLimitedAt - b.rateLimitedAt;
-      }
-
-      return a.lastUsed - b.lastUsed;
-    });
-
-    const selectedKey = keysByPriority[0];
-    selectedKey.lastUsed = now;
+    const selectedKey = prioritizeKeys(availableKeys)[0];
+    selectedKey.lastUsed = Date.now();
     this.throttle(selectedKey.hash);
     return { ...selectedKey };
   }
@@ -185,22 +160,7 @@ export class GcpKeyProvider implements KeyProvider {
     key[`${getGcpModelFamily(model)}Tokens`] += tokens;
   }

-  public getLockoutPeriod() {
-    // TODO: same exact behavior for three providers, should be refactored
-    const activeKeys = this.keys.filter((k) => !k.isDisabled);
-    // Don't lock out if there are no keys available or the queue will stall.
-    // Just let it through so the add-key middleware can throw an error.
-    if (activeKeys.length === 0) return 0;
-
-    const now = Date.now();
-    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
-    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
-
-    if (anyNotRateLimited) return 0;
-
-    // If all keys are rate-limited, return time until the first key is ready.
-    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
-  }
+  getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);

   /**
    * This is called when we receive a 429, which means there are already five
diff --git a/src/shared/key-management/google-ai/provider.ts b/src/shared/key-management/google-ai/provider.ts
index 47695f5..a7abfc1 100644
--- a/src/shared/key-management/google-ai/provider.ts
+++ b/src/shared/key-management/google-ai/provider.ts
@@ -1,9 +1,10 @@
 import crypto from "crypto";
-import { Key, KeyProvider } from "..";
 import { config } from "../../../config";
 import { logger } from "../../../logger";
-import { getGoogleAIModelFamily, type GoogleAIModelFamily } from "../../models";
 import { PaymentRequiredError } from "../../errors";
+import { getGoogleAIModelFamily, type GoogleAIModelFamily } from "../../models";
+import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
+import { prioritizeKeys } from "../prioritize-keys";
 import { GoogleAIKeyChecker } from "./checker";

 // Note that Google AI is not the same as Vertex AI, both are provided by
@@ -28,10 +29,6 @@ type GoogleAIKeyUsage = {
 export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
   readonly service: "google-ai";
   readonly modelFamilies: GoogleAIModelFamily[];
-  /** The time at which this key was last rate limited. */
-  rateLimitedAt: number;
-  /** The time until which this key is rate limited. */
-  rateLimitedUntil: number;
   /** All detected model IDs on this key. */
   modelIds: string[];
 }
@@ -112,29 +109,10 @@ export class GoogleAIKeyProvider implements KeyProvider {
     throw new PaymentRequiredError("No Google AI keys available");
   }

-    // Select a key, from highest priority to lowest priority:
-    // 1. Keys which are not rate limited
-    //   a. If all keys were rate limited recently, select the least-recently
-    //      rate limited key.
-    // 3. Keys which have not been used in the longest time
-
-    const now = Date.now();
-
-    const keysByPriority = availableKeys.sort((a, b) => {
-      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
-      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
-
-      if (aRateLimited && !bRateLimited) return 1;
-      if (!aRateLimited && bRateLimited) return -1;
-      if (aRateLimited && bRateLimited) {
-        return a.rateLimitedAt - b.rateLimitedAt;
-      }
-
-      return a.lastUsed - b.lastUsed;
-    });
+    const keysByPriority = prioritizeKeys(availableKeys);

     const selectedKey = keysByPriority[0];
-    selectedKey.lastUsed = now;
+    selectedKey.lastUsed = Date.now();
     this.throttle(selectedKey.hash);
     return { ...selectedKey };
   }
@@ -162,22 +140,7 @@ export class GoogleAIKeyProvider implements KeyProvider {
     key[`${getGoogleAIModelFamily(model)}Tokens`] += tokens;
   }

-  public getLockoutPeriod() {
-    const activeKeys = this.keys.filter((k) => !k.isDisabled);
-    // Don't lock out if there are no keys available or the queue will stall.
-    // Just let it through so the add-key middleware can throw an error.
-    if (activeKeys.length === 0) return 0;
-
-    const now = Date.now();
-    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
-    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
-
-    if (anyNotRateLimited) return 0;
-
-    // If all keys are rate-limited, return the time until the first key is
-    // ready.
-    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
-  }
+  getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);

   /**
    * This is called when we receive a 429, which means there are already five
diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts
index 67dfad4..1dd58dc 100644
--- a/src/shared/key-management/index.ts
+++ b/src/shared/key-management/index.ts
@@ -9,7 +9,8 @@ export type APIFormat =
   | "anthropic-chat" // Anthropic's newer messages array format
   | "anthropic-text" // Legacy flat string prompt format
   | "google-ai"
-  | "mistral-ai";
+  | "mistral-ai"
+  | "mistral-text";

 export interface Key {
   /** The API key itself. Never log this, use `hash` instead. */
@@ -30,6 +31,10 @@ export interface Key {
   lastChecked: number;
   /** Hash of the key, for logging and to find the key in the pool. */
   hash: string;
+  /** The time at which this key was last rate limited. */
+  rateLimitedAt: number;
+  /**
+   * The time until which this key is rate limited.
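+   * A value in the past (or the initial 0, as set in the provider
+   * constructors) means the key is not currently rate limited.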
+   */
+  rateLimitedUntil: number;
 }

 /*
@@ -58,10 +63,32 @@ export interface KeyProvider {
   recheck(): void;
 }

+export function createGenericGetLockoutPeriod<T extends Key>(
+  getKeys: () => T[]
+) {
+  return function (this: unknown, family?: ModelFamily): number {
+    const keys = getKeys();
+    const activeKeys = keys.filter(
+      (k) => !k.isDisabled && (!family || k.modelFamilies.includes(family))
+    );
+
+    if (activeKeys.length === 0) return 0;
+
+    const now = Date.now();
+    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
+    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
+
+    if (anyNotRateLimited) return 0;
+
+    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
+  };
+}
+
 export const keyPool = new KeyPool();

 export { AnthropicKey } from "./anthropic/provider";
-export { OpenAIKey } from "./openai/provider";
-export { GoogleAIKey } from "././google-ai/provider";
 export { AwsBedrockKey } from "./aws/provider";
 export { GcpKey } from "./gcp/provider";
 export { AzureOpenAIKey } from "./azure/provider";
+export { GoogleAIKey } from "./google-ai/provider";
+export { MistralAIKey } from "./mistral-ai/provider";
+export { OpenAIKey } from "./openai/provider";
diff --git a/src/shared/key-management/mistral-ai/provider.ts b/src/shared/key-management/mistral-ai/provider.ts
index 83785f8..a8460e6 100644
--- a/src/shared/key-management/mistral-ai/provider.ts
+++ b/src/shared/key-management/mistral-ai/provider.ts
@@ -1,10 +1,11 @@
 import crypto from "crypto";
-import { Key, KeyProvider } from "..";
 import { config } from "../../../config";
 import { logger } from "../../../logger";
-import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
-import { MistralAIKeyChecker } from "./checker";
 import { HttpError } from "../../errors";
+import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
+import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
+import { prioritizeKeys } from "../prioritize-keys";
+import { MistralAIKeyChecker } from "./checker";

 type MistralAIKeyUsage = {
@@ -13,10 +14,6 @@ export interface MistralAIKey extends Key, MistralAIKeyUsage {
   readonly service: "mistral-ai";
   readonly modelFamilies: MistralAIModelFamily[];
-  /** The time at which this key was last rate limited. */
-  rateLimitedAt: number;
-  /** The time until which this key is rate limited. */
-  rateLimitedUntil: number;
 }

 /**
@@ -98,30 +95,8 @@ export class MistralAIKeyProvider implements KeyProvider {
     throw new HttpError(402, "No Mistral AI keys available");
   }

-    // (largely copied from the OpenAI provider, without trial key support)
-    // Select a key, from highest priority to lowest priority:
-    // 1. Keys which are not rate limited
-    //   a. If all keys were rate limited recently, select the least-recently
-    //      rate limited key.
-    // 3. Keys which have not been used in the longest time
-
-    const now = Date.now();
-
-    const keysByPriority = availableKeys.sort((a, b) => {
-      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
-      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
-
-      if (aRateLimited && !bRateLimited) return 1;
-      if (!aRateLimited && bRateLimited) return -1;
-      if (aRateLimited && bRateLimited) {
-        return a.rateLimitedAt - b.rateLimitedAt;
-      }
-
-      return a.lastUsed - b.lastUsed;
-    });
-
-    const selectedKey = keysByPriority[0];
-    selectedKey.lastUsed = now;
+    const selectedKey = prioritizeKeys(availableKeys)[0];
+    selectedKey.lastUsed = Date.now();
     this.throttle(selectedKey.hash);
     return { ...selectedKey };
   }
@@ -150,22 +125,7 @@ export class MistralAIKeyProvider implements KeyProvider {
     key[`${family}Tokens`] += tokens;
   }

-  public getLockoutPeriod() {
-    const activeKeys = this.keys.filter((k) => !k.isDisabled);
-    // Don't lock out if there are no keys available or the queue will stall.
-    // Just let it through so the add-key middleware can throw an error.
-    if (activeKeys.length === 0) return 0;
-
-    const now = Date.now();
-    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
-    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
-
-    if (anyNotRateLimited) return 0;
-
-    // If all keys are rate-limited, return the time until the first key is
-    // ready.
-    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
-  }
+  getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);

   /**
    * This is called when we receive a 429, which means there are already five
diff --git a/src/shared/key-management/openai/provider.ts b/src/shared/key-management/openai/provider.ts
index 809262c..528f029 100644
--- a/src/shared/key-management/openai/provider.ts
+++ b/src/shared/key-management/openai/provider.ts
@@ -26,8 +26,6 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage {
   isTrial: boolean;
   /** Set when key check returns a non-transient 429. */
   isOverQuota: boolean;
-  /** The time at which this key was last rate limited. */
-  rateLimitedAt: number;
   /**
    * Last known X-RateLimit-Requests-Reset header from OpenAI, converted to a
    * number.
@@ -111,6 +109,7 @@ export class OpenAIKeyProvider implements KeyProvider {
       .digest("hex")
       .slice(0, 8)}`,
     rateLimitedAt: 0,
+    rateLimitedUntil: 0,
     rateLimitRequestsReset: 0,
     rateLimitTokensReset: 0,
     turboTokens: 0,
diff --git a/src/shared/key-management/prioritize-keys.ts b/src/shared/key-management/prioritize-keys.ts
new file mode 100644
index 0000000..cf52995
--- /dev/null
+++ b/src/shared/key-management/prioritize-keys.ts
@@ -0,0 +1,24 @@
+import { Key } from "./index";
+
+export function prioritizeKeys<T extends Key>(keys: T[]) {
+  // Sorts keys from highest priority to lowest priority, where priority is:
+  // 1. Keys which are not rate limited
+  //   a. If all keys were rate limited recently, select the least-recently
+  //      rate limited key.
+  // 2. Keys which have not been used in the longest time
+
+  const now = Date.now();
+
+  return keys.sort((a, b) => {
+    const aRateLimited = now < a.rateLimitedUntil;
+    const bRateLimited = now < b.rateLimitedUntil;
+
+    if (aRateLimited && !bRateLimited) return 1;
+    if (!aRateLimited && bRateLimited) return -1;
+    if (aRateLimited && bRateLimited) {
+      return a.rateLimitedAt - b.rateLimitedAt;
+    }
+
+    return a.lastUsed - b.lastUsed;
+  });
+}
diff --git a/src/shared/models.ts b/src/shared/models.ts
index 2004374..66255bf 100644
--- a/src/shared/models.ts
+++ b/src/shared/models.ts
@@ -32,7 +32,9 @@ export type MistralAIModelFamily =
   // mistral changes their model classes frequently so these no longer
   // correspond to specific models. consider them rough pricing tiers.
   "mistral-tiny" | "mistral-small" | "mistral-medium" | "mistral-large";
-export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
+export type AwsBedrockModelFamily = `aws-${
+  | AnthropicModelFamily
+  | MistralAIModelFamily}`;
 export type GcpModelFamily = "gcp-claude" | "gcp-claude-opus";
 export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
 export type ModelFamily =
@@ -64,6 +66,10 @@ export const MODEL_FAMILIES = ((
   "mistral-large",
   "aws-claude",
   "aws-claude-opus",
+  "aws-mistral-tiny",
+  "aws-mistral-small",
+  "aws-mistral-medium",
+  "aws-mistral-large",
   "gcp-claude",
   "gcp-claude-opus",
   "azure-turbo",
@@ -99,6 +105,10 @@ export const MODEL_FAMILY_SERVICE: {
   "claude-opus": "anthropic",
   "aws-claude": "aws",
   "aws-claude-opus": "aws",
+  "aws-mistral-tiny": "aws",
+  "aws-mistral-small": "aws",
+  "aws-mistral-medium": "aws",
+  "aws-mistral-large": "aws",
   "gcp-claude": "gcp",
   "gcp-claude-opus": "gcp",
   "azure-turbo": "azure",
@@ -180,8 +190,16 @@ export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
 }

 export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily {
-  if (model.includes("opus")) return "aws-claude-opus";
-  return "aws-claude";
+  // remove vendor and version from AWS model ids
+  // 'anthropic.claude-3-5-sonnet-20240620-v1:0' -> 'claude-3-5-sonnet-20240620'
+  const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d+)?(:\d+)*$/, "$2");
+
+  if (["claude", "anthropic"].some((x) => model.includes(x))) {
+    return `aws-${getClaudeModelFamily(deAwsified)}`;
+  } else if (model.includes("tral")) {
+    return `aws-${getMistralAIModelFamily(deAwsified)}`;
+  }
+  return `aws-claude`;
 }

 export function getGcpModelFamily(model: string): GcpModelFamily {
@@ -223,8 +241,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
   const model = req.body.model ?? "gpt-3.5-turbo";
   let modelFamily: ModelFamily;

-  // Weird special case for AWS/GCP/Azure because they serve multiple models from
-  // different vendors, even if currently only one is supported.
+  // Weird special case for AWS/GCP/Azure because they serve models with
+  // different API formats, so the outbound API alone is not sufficient to
+  // determine the partition.
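+  // (e.g. an AWS Bedrock request can name either a Claude or a Mistral model,
+  // which map to the aws-claude* and aws-mistral* families respectively)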
   if (req.service === "aws") {
     modelFamily = getAwsBedrockModelFamily(model);
   } else if (req.service === "gcp") {
@@ -246,6 +265,7 @@
       modelFamily = getGoogleAIModelFamily(model);
       break;
     case "mistral-ai":
+    case "mistral-text":
       modelFamily = getMistralAIModelFamily(model);
       break;
     default:
diff --git a/src/shared/tokenization/tokenizer.ts b/src/shared/tokenization/tokenizer.ts
index 1b03f3d..864f0bb 100644
--- a/src/shared/tokenization/tokenizer.ts
+++ b/src/shared/tokenization/tokenizer.ts
@@ -47,9 +47,9 @@ type GoogleAIChatTokenCountRequest = {
 };

 type MistralAIChatTokenCountRequest = {
-  prompt: MistralAIChatMessage[];
+  prompt: string | MistralAIChatMessage[];
   completion?: never;
-  service: "mistral-ai";
+  service: "mistral-ai" | "mistral-text";
 };

 type FlatPromptTokenCountRequest = {
@@ -128,6 +128,7 @@ export async function countTokens({
         tokenization_duration_ms: getElapsedMs(time),
       };
     case "mistral-ai":
+    case "mistral-text":
       return {
         ...getMistralAITokenCount(prompt ?? completion),
         tokenization_duration_ms: getElapsedMs(time),
diff --git a/src/shared/users/user-store.ts b/src/shared/users/user-store.ts
index 9718c90..d6f3707 100644
--- a/src/shared/users/user-store.ts
+++ b/src/shared/users/user-store.ts
@@ -431,6 +431,7 @@ function getModelFamilyForQuotaUsage(
     case "google-ai":
       return getGoogleAIModelFamily(model);
     case "mistral-ai":
+    case "mistral-text":
       return getMistralAIModelFamily(model);
     default:
       assertNever(api);