adds /aws/mistral endpoint

This commit is contained in:
nai-degen
2024-08-11 13:09:08 -05:00
parent 9e5a660ef5
commit 2d8e1dac13
4 changed files with 132 additions and 15 deletions
+15 -8
View File
@@ -1,7 +1,6 @@
import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
@@ -16,7 +15,10 @@ import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
} from "./middleware/response";
import { transformAnthropicChatResponseToAnthropicText, transformAnthropicChatResponseToOpenAI } from "./anthropic";
import {
transformAnthropicChatResponseToAnthropicText,
transformAnthropicChatResponseToOpenAI,
} from "./anthropic";
/** Only used for non-streaming requests. */
const awsResponseHandler: ProxyResHandlerWithBody = async (
@@ -87,7 +89,7 @@ function transformAwsTextResponseToOpenAI(
};
}
const awsProxy = createQueueMiddleware({
const awsClaudeProxy = createQueueMiddleware({
beforeProxy: signAwsRequest,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
@@ -152,7 +154,12 @@ const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
const awsClaudeRouter = Router();
// Native(ish) Anthropic text completion endpoint.
awsClaudeRouter.post("/v1/complete", ipLimiter, preprocessAwsTextRequest, awsProxy);
awsClaudeRouter.post(
"/v1/complete",
ipLimiter,
preprocessAwsTextRequest,
awsClaudeProxy
);
// Native Anthropic chat completion endpoint.
awsClaudeRouter.post(
"/v1/messages",
@@ -161,7 +168,7 @@ awsClaudeRouter.post(
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
),
awsProxy
awsClaudeProxy
);
// OpenAI-to-AWS Anthropic compatibility endpoint.
@@ -169,7 +176,7 @@ awsClaudeRouter.post(
"/v1/chat/completions",
ipLimiter,
preprocessOpenAICompatRequest,
awsProxy
awsClaudeProxy
);
/**
@@ -179,7 +186,7 @@ awsClaudeRouter.post(
* - frontends sending Anthropic model names that AWS doesn't recognize
* - frontends sending OpenAI model names because they expect the proxy to
* translate them
*
*
* If client sends AWS model ID it will be used verbatim. Otherwise, various
* strategies are used to try to map a non-AWS model name to AWS model ID.
*/
@@ -212,7 +219,7 @@ function maybeReassignModel(req: Request) {
req.body.model = "anthropic.claude-instant-v1";
return;
}
const ver = minor ? `${major}.${minor}` : major;
switch (ver) {
case "1":
+100
View File
@@ -0,0 +1,100 @@
import { Request } from "express";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { createQueueMiddleware } from "./queue";
import {
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeSignedRequest,
signAwsRequest,
} from "./middleware/request";
import { createProxyMiddleware } from "http-proxy-middleware";
import { logger } from "../logger";
import { handleProxyError } from "./middleware/common";
import { Router } from "express";
import { ipLimiter } from "./rate-limit";
/**
 * Handles non-streaming Bedrock Mistral responses: backfills the model name
 * when AWS omits it, then relays the body to the client with a 200.
 */
const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  // Bedrock does not always echo the model in its responses, so copy it
  // over from the request when it is missing.
  const requestedModel = req.body.model;
  if (!body.model && requestedModel) {
    body.model = requestedModel;
  }

  res.status(200).json({ ...body, proxy: body.proxy });
};
// Queue-aware proxy middleware for AWS Bedrock Mistral requests. Each request
// is SigV4-signed before entering the queue, then routed to the host named in
// the signed request.
const awsMistralProxy = createQueueMiddleware({
  beforeProxy: signAwsRequest,
  proxyMiddleware: createProxyMiddleware({
    // Placeholder target; the `router` callback below rewrites it per-request
    // from the signed request's protocol and hostname.
    target: "bad-target-will-be-rewritten",
    router: ({ signedRequest }) => {
      if (!signedRequest) throw new Error("Must sign request before proxying");
      return `${signedRequest.protocol}//${signedRequest.hostname}`;
    },
    changeOrigin: true,
    // Responses are consumed by our own handlers rather than piped straight
    // back to the client.
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
      proxyRes: createOnProxyResHandler([awsMistralBlockingResponseHandler]),
      error: handleProxyError,
    },
  }),
});
/**
 * Maps friendly Mistral model names (e.g. "mistral-large", "mixtral-8x7b") to
 * Amazon Bedrock model IDs. A model already prefixed with "mistral." is
 * assumed to be a Bedrock ID and is used verbatim.
 *
 * Substring checks are ordered from most to least specific so a general
 * pattern cannot shadow a more specific one: "8x7b" contains "7b", and
 * "large-2402" contains "large".
 *
 * @throws Error if the model name cannot be mapped to a supported Bedrock ID.
 */
function maybeReassignModel(req: Request) {
  const model = req.body.model;

  // If it looks like an AWS model, use it as-is
  if (model.startsWith("mistral.")) {
    return;
  }

  // Mixtral 8x7B Instruct — must be checked before the 7B pattern, because
  // "8x7b" also matches the "7b" substring.
  if (model.includes("8x7b")) {
    req.body.model = "mistral.mixtral-8x7b-instruct-v0:1";
  }
  // Mistral 7B Instruct
  else if (model.includes("7b")) {
    req.body.model = "mistral.mistral-7b-instruct-v0:2";
  }
  // Mistral Large (Feb 2024) — must precede the generic "large" pattern.
  else if (model.includes("large-2402")) {
    req.body.model = "mistral.mistral-large-2402-v1:0";
  }
  // Mistral Large 2 (July 2024)
  else if (model.includes("large")) {
    req.body.model = "mistral.mistral-large-2407-v1:0";
  }
  // Mistral Small (Feb 2024)
  else if (model.includes("small")) {
    req.body.model = "mistral.mistral-small-2402-v1:0";
  } else {
    throw new Error(`Can't map '${model}' to a supported AWS model ID; make sure you are requesting a Mistral model supported by Amazon Bedrock`);
  }
}
// Preprocessor for native Mistral chat requests: no API translation needed
// (mistral-ai in and out); after transformation, friendly model names are
// remapped to Bedrock model IDs.
const nativeMistralChatPreprocessor = createPreprocessorMiddleware(
  { inApi: "mistral-ai", outApi: "mistral-ai", service: "aws" },
  { afterTransform: [maybeReassignModel] }
);
// Router exposing the Bedrock Mistral chat completions endpoint, mounted by
// the parent AWS router.
const awsMistralRouter = Router();
awsMistralRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  nativeMistralChatPreprocessor,
  awsMistralProxy
);
export const awsMistral = awsMistralRouter;
+17 -4
View File
@@ -2,12 +2,14 @@
import { Request, Response, Router } from "express";
import { config } from "../config";
import { awsClaude } from "./aws-claude";
import { addV1 } from "./add-v1";
import { awsClaude } from "./aws-claude";
import { awsMistral } from "./aws-mistral";
import { AwsBedrockKey, keyPool } from "../shared/key-management";
const awsRouter = Router();
awsRouter.use("/claude", addV1, awsClaude);
// awsRouter.use("/mistral", addV1, awsMistralRouter);
awsRouter.use("/mistral", addV1, awsMistral);
awsRouter.get("/:vendor?/models", handleModelsRequest);
const MODELS_CACHE_TTL = 10000;
@@ -19,6 +21,12 @@ function handleModelsRequest(req: Request, res: Response) {
return res.json(modelsCache);
}
const availableModelIds = new Set<string>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "aws") continue;
(key as AwsBedrockKey).modelIds.forEach((id) => availableModelIds.add(id));
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
const models = [
"anthropic.claude-v2",
@@ -32,7 +40,9 @@ function handleModelsRequest(req: Request, res: Response) {
"mistral.mistral-large-2402-v1:0",
"mistral.mistral-large-2407-v1:0",
"mistral.mistral-small-2402-v1:0",
].map((id) => {
]
.filter((id) => availableModelIds.has(id))
.map((id) => {
const vendor = id.match(/^(.*)\./)?.[1];
return {
id,
@@ -47,7 +57,10 @@ function handleModelsRequest(req: Request, res: Response) {
const requestedVendor = req.params.vendor;
const vendor = requestedVendor === "claude" ? "anthropic" : requestedVendor;
modelsCache = { object: "list", data: models.filter((m) => m.root === vendor) };
modelsCache = {
object: "list",
data: models.filter((m) => m.root === vendor),
};
modelsCacheTime = new Date().getTime();
return res.json(modelsCache);
@@ -72,9 +72,6 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
.slice(0, 8)}`,
lastChecked: 0,
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
// sonnetEnabled: true,
// haikuEnabled: false,
// sonnet35Enabled: false,
["aws-claudeTokens"]: 0,
["aws-claude-opusTokens"]: 0,
["aws-mistral-tinyTokens"]: 0,