diff --git a/src/proxy/aws-claude.ts b/src/proxy/aws-claude.ts index a3ab783..a00e4e6 100644 --- a/src/proxy/aws-claude.ts +++ b/src/proxy/aws-claude.ts @@ -1,7 +1,6 @@ import { Request, RequestHandler, Router } from "express"; import { createProxyMiddleware } from "http-proxy-middleware"; import { v4 } from "uuid"; -import { config } from "../config"; import { logger } from "../logger"; import { createQueueMiddleware } from "./queue"; import { ipLimiter } from "./rate-limit"; @@ -16,7 +15,10 @@ import { ProxyResHandlerWithBody, createOnProxyResHandler, } from "./middleware/response"; -import { transformAnthropicChatResponseToAnthropicText, transformAnthropicChatResponseToOpenAI } from "./anthropic"; +import { + transformAnthropicChatResponseToAnthropicText, + transformAnthropicChatResponseToOpenAI, +} from "./anthropic"; /** Only used for non-streaming requests. */ const awsResponseHandler: ProxyResHandlerWithBody = async ( @@ -87,7 +89,7 @@ function transformAwsTextResponseToOpenAI( }; } -const awsProxy = createQueueMiddleware({ +const awsClaudeProxy = createQueueMiddleware({ beforeProxy: signAwsRequest, proxyMiddleware: createProxyMiddleware({ target: "bad-target-will-be-rewritten", @@ -152,7 +154,12 @@ const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => { const awsClaudeRouter = Router(); // Native(ish) Anthropic text completion endpoint. -awsClaudeRouter.post("/v1/complete", ipLimiter, preprocessAwsTextRequest, awsProxy); +awsClaudeRouter.post( + "/v1/complete", + ipLimiter, + preprocessAwsTextRequest, + awsClaudeProxy +); // Native Anthropic chat completion endpoint. awsClaudeRouter.post( "/v1/messages", @@ -161,7 +168,7 @@ awsClaudeRouter.post( { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" }, { afterTransform: [maybeReassignModel] } ), - awsProxy + awsClaudeProxy ); // OpenAI-to-AWS Anthropic compatibility endpoint. 
/**
 * Response handler for AWS Mistral requests whose upstream body has already
 * been collected (presumably non-streaming responses, as the name suggests).
 *
 * Replies to the client with HTTP 200 and the upstream body as JSON. If the
 * upstream body carries no `model` but the request did, the request's model
 * is copied onto the body before it is sent.
 *
 * @param _proxyRes Upstream response object; not used here.
 * @param req Incoming request; only `req.body.model` is read.
 * @param res Client response used to send the JSON payload.
 * @param body Upstream response body. Mutated in place when `model` is filled in.
 * @throws Error if `body` is `null` or not an object.
 */
const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  // `typeof null` is "object", so null must be rejected explicitly; otherwise
  // the property reads below would fail with an unrelated TypeError.
  if (typeof body !== "object" || body === null) {
    throw new Error("Expected body to be an object");
  }

  // AWS does not always confirm the model in the response, so we have to add it
  if (!body.model && req.body.model) {
    body.model = req.body.model;
  }

  // `proxy` is re-set from the body itself, so its value is passed through
  // unchanged alongside the other fields.
  res.status(200).json({ ...body, proxy: body.proxy });
};
/**
 * Maps the model name in `req.body.model` to an AWS Bedrock Mistral model ID,
 * rewriting `req.body.model` in place.
 *
 * A name that already starts with `mistral.` is treated as an AWS model ID and
 * left untouched. Otherwise the name is matched against a list of substrings,
 * most specific first, and replaced by the corresponding AWS model ID.
 *
 * @param req Incoming request; `req.body.model` is read and possibly replaced.
 * @throws Error if the model is missing, not a string, or matches no known
 *   Mistral model family.
 */
function maybeReassignModel(req: Request) {
  const model = req.body.model;

  const unsupported = () =>
    new Error(`Can't map '${model}' to a supported AWS model ID; make sure you are requesting a Mistral model supported by Amazon Bedrock`);

  // A missing or non-string model cannot be matched; fail with the same
  // descriptive error instead of a TypeError from calling a string method.
  if (typeof model !== "string") {
    throw unsupported();
  }

  // If it looks like an AWS model, use it as-is
  if (model.startsWith("mistral.")) {
    return;
  }

  // Order matters: a more specific pattern must precede any pattern it
  // contains ("8x7b" contains "7b", "large-2402" contains "large").
  const mappings: ReadonlyArray<readonly [string, string]> = [
    // Mixtral 8x7B Instruct
    ["8x7b", "mistral.mixtral-8x7b-instruct-v0:1"],
    // Mistral 7B Instruct
    ["7b", "mistral.mistral-7b-instruct-v0:2"],
    // Mistral Large (Feb 2024)
    ["large-2402", "mistral.mistral-large-2402-v1:0"],
    // Mistral Large 2 (July 2024)
    ["large", "mistral.mistral-large-2407-v1:0"],
    // Mistral Small (Feb 2024)
    ["small", "mistral.mistral-small-2402-v1:0"],
  ];

  const match = mappings.find(([pattern]) => model.includes(pattern));
  if (!match) {
    throw unsupported();
  }
  req.body.model = match[1];
}
Request, Response, Router } from "express"; import { config } from "../config"; -import { awsClaude } from "./aws-claude"; import { addV1 } from "./add-v1"; +import { awsClaude } from "./aws-claude"; +import { awsMistral } from "./aws-mistral"; +import { AwsBedrockKey, keyPool } from "../shared/key-management"; const awsRouter = Router(); awsRouter.use("/claude", addV1, awsClaude); -// awsRouter.use("/mistral", addV1, awsMistralRouter); +awsRouter.use("/mistral", addV1, awsMistral); awsRouter.get("/:vendor?/models", handleModelsRequest); const MODELS_CACHE_TTL = 10000; @@ -19,6 +21,12 @@ function handleModelsRequest(req: Request, res: Response) { return res.json(modelsCache); } + const availableModelIds = new Set(); + for (const key of keyPool.list()) { + if (key.isDisabled || key.service !== "aws") continue; + (key as AwsBedrockKey).modelIds.forEach((id) => availableModelIds.add(id)); + } + // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html const models = [ "anthropic.claude-v2", @@ -32,7 +40,9 @@ function handleModelsRequest(req: Request, res: Response) { "mistral.mistral-large-2402-v1:0", "mistral.mistral-large-2407-v1:0", "mistral.mistral-small-2402-v1:0", - ].map((id) => { + ] + .filter((id) => availableModelIds.has(id)) + .map((id) => { const vendor = id.match(/^(.*)\./)?.[1]; return { id, @@ -47,7 +57,10 @@ function handleModelsRequest(req: Request, res: Response) { const requestedVendor = req.params.vendor; const vendor = requestedVendor === "claude" ? 
"anthropic" : requestedVendor; - modelsCache = { object: "list", data: models.filter((m) => m.root === vendor) }; + modelsCache = { + object: "list", + data: models.filter((m) => m.root === vendor), + }; modelsCacheTime = new Date().getTime(); return res.json(modelsCache); diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts index 97d320d..76d10b5 100644 --- a/src/shared/key-management/aws/provider.ts +++ b/src/shared/key-management/aws/provider.ts @@ -72,9 +72,6 @@ export class AwsBedrockKeyProvider implements KeyProvider { .slice(0, 8)}`, lastChecked: 0, modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"], - // sonnetEnabled: true, - // haikuEnabled: false, - // sonnet35Enabled: false, ["aws-claudeTokens"]: 0, ["aws-claude-opusTokens"]: 0, ["aws-mistral-tinyTokens"]: 0,