adds /aws/mistral endpoint

This commit is contained in:
nai-degen
2024-08-11 13:09:08 -05:00
parent 9e5a660ef5
commit 2d8e1dac13
4 changed files with 132 additions and 15 deletions
+15 -8
View File
@@ -1,7 +1,6 @@
import { Request, RequestHandler, Router } from "express"; import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware"; import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid"; import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger"; import { logger } from "../logger";
import { createQueueMiddleware } from "./queue"; import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit"; import { ipLimiter } from "./rate-limit";
@@ -16,7 +15,10 @@ import {
ProxyResHandlerWithBody, ProxyResHandlerWithBody,
createOnProxyResHandler, createOnProxyResHandler,
} from "./middleware/response"; } from "./middleware/response";
import { transformAnthropicChatResponseToAnthropicText, transformAnthropicChatResponseToOpenAI } from "./anthropic"; import {
transformAnthropicChatResponseToAnthropicText,
transformAnthropicChatResponseToOpenAI,
} from "./anthropic";
/** Only used for non-streaming requests. */ /** Only used for non-streaming requests. */
const awsResponseHandler: ProxyResHandlerWithBody = async ( const awsResponseHandler: ProxyResHandlerWithBody = async (
@@ -87,7 +89,7 @@ function transformAwsTextResponseToOpenAI(
}; };
} }
const awsProxy = createQueueMiddleware({ const awsClaudeProxy = createQueueMiddleware({
beforeProxy: signAwsRequest, beforeProxy: signAwsRequest,
proxyMiddleware: createProxyMiddleware({ proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten", target: "bad-target-will-be-rewritten",
@@ -152,7 +154,12 @@ const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
const awsClaudeRouter = Router(); const awsClaudeRouter = Router();
// Native(ish) Anthropic text completion endpoint. // Native(ish) Anthropic text completion endpoint.
awsClaudeRouter.post("/v1/complete", ipLimiter, preprocessAwsTextRequest, awsProxy); awsClaudeRouter.post(
"/v1/complete",
ipLimiter,
preprocessAwsTextRequest,
awsClaudeProxy
);
// Native Anthropic chat completion endpoint. // Native Anthropic chat completion endpoint.
awsClaudeRouter.post( awsClaudeRouter.post(
"/v1/messages", "/v1/messages",
@@ -161,7 +168,7 @@ awsClaudeRouter.post(
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" }, { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] } { afterTransform: [maybeReassignModel] }
), ),
awsProxy awsClaudeProxy
); );
// OpenAI-to-AWS Anthropic compatibility endpoint. // OpenAI-to-AWS Anthropic compatibility endpoint.
@@ -169,7 +176,7 @@ awsClaudeRouter.post(
"/v1/chat/completions", "/v1/chat/completions",
ipLimiter, ipLimiter,
preprocessOpenAICompatRequest, preprocessOpenAICompatRequest,
awsProxy awsClaudeProxy
); );
/** /**
@@ -179,7 +186,7 @@ awsClaudeRouter.post(
* - frontends sending Anthropic model names that AWS doesn't recognize * - frontends sending Anthropic model names that AWS doesn't recognize
* - frontends sending OpenAI model names because they expect the proxy to * - frontends sending OpenAI model names because they expect the proxy to
* translate them * translate them
* *
* If client sends AWS model ID it will be used verbatim. Otherwise, various * If client sends AWS model ID it will be used verbatim. Otherwise, various
* strategies are used to try to map a non-AWS model name to AWS model ID. * strategies are used to try to map a non-AWS model name to AWS model ID.
*/ */
@@ -212,7 +219,7 @@ function maybeReassignModel(req: Request) {
req.body.model = "anthropic.claude-instant-v1"; req.body.model = "anthropic.claude-instant-v1";
return; return;
} }
const ver = minor ? `${major}.${minor}` : major; const ver = minor ? `${major}.${minor}` : major;
switch (ver) { switch (ver) {
case "1": case "1":
+100
View File
@@ -0,0 +1,100 @@
import { Request } from "express";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { createQueueMiddleware } from "./queue";
import {
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeSignedRequest,
signAwsRequest,
} from "./middleware/request";
import { createProxyMiddleware } from "http-proxy-middleware";
import { logger } from "../logger";
import { handleProxyError } from "./middleware/common";
import { Router } from "express";
import { ipLimiter } from "./rate-limit";
/**
 * Response handler for the AWS Mistral proxy that replies with the complete
 * upstream body as JSON (the "blocking" path, going by the handler's name).
 *
 * Rejects if `body` is not an object. Otherwise it backfills `body.model`
 * from the client's request when the upstream response has none, then sends
 * the body to the client with HTTP 200.
 */
const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  // AWS does not always confirm the model in the response, so fall back to
  // the model the client asked for. Note that this mutates `body` in place.
  if (!body.model) {
    const requestedModel = req.body.model;
    if (requestedModel) {
      body.model = requestedModel;
    }
  }

  const payload = { ...body, proxy: body.proxy };
  res.status(200).json(payload);
};
/**
 * Queue-wrapped proxy middleware for AWS-hosted Mistral models.
 *
 * `signAwsRequest` is registered as the `beforeProxy` hook. The static
 * `target` is only a placeholder: the `router` callback derives the real
 * upstream origin from the request's `signedRequest` and throws if the
 * request has not been signed.
 */
const awsMistralProxy = createQueueMiddleware({
  beforeProxy: signAwsRequest,
  proxyMiddleware: createProxyMiddleware({
    target: "bad-target-will-be-rewritten",
    router(req) {
      const signed = req.signedRequest;
      if (!signed) throw new Error("Must sign request before proxying");
      // `protocol` is expected to carry its trailing colon (e.g. "https:"),
      // hence the bare "//" separator. TODO confirm against signAwsRequest.
      return `${signed.protocol}//${signed.hostname}`;
    },
    changeOrigin: true,
    // Responses are handled by the `proxyRes` handler chain below.
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
      proxyRes: createOnProxyResHandler([awsMistralBlockingResponseHandler]),
      error: handleProxyError,
    },
  }),
});
/**
 * Rewrites `req.body.model` in place to an Amazon Bedrock Mistral model ID.
 *
 * A model name that already starts with `mistral.` is treated as a Bedrock
 * model ID and left untouched. Any other name is matched by substring, with
 * the more specific patterns tested first so that a shorter pattern cannot
 * shadow a longer one.
 *
 * @param req Incoming request; only `req.body.model` is read and written.
 * @throws Error if the model is missing, is not a string, or cannot be
 *   mapped to one of the Bedrock model IDs listed below.
 */
function maybeReassignModel(req: Request) {
  const model = req.body.model;

  // A missing or non-string model would otherwise crash on `.startsWith`
  // with an opaque TypeError; report it through the regular mapping error.
  if (typeof model !== "string") {
    throw new Error(`Can't map '${model}' to a supported AWS model ID; make sure you are requesting a Mistral model supported by Amazon Bedrock`);
  }

  // If it looks like an AWS model, use it as-is
  if (model.startsWith("mistral.")) {
    return;
  }
  // Mixtral 8x7B Instruct. Must be tested before the plain "7b" pattern,
  // because "8x7b" contains "7b" and would otherwise never be reached.
  else if (model.includes("8x7b")) {
    req.body.model = "mistral.mixtral-8x7b-instruct-v0:1";
  }
  // Mistral 7B Instruct
  else if (model.includes("7b")) {
    req.body.model = "mistral.mistral-7b-instruct-v0:2";
  }
  // Mistral Large (Feb 2024). Must be tested before the generic "large".
  else if (model.includes("large-2402")) {
    req.body.model = "mistral.mistral-large-2402-v1:0";
  }
  // Mistral Large 2 (July 2024)
  else if (model.includes("large")) {
    req.body.model = "mistral.mistral-large-2407-v1:0";
  }
  // Mistral Small (Feb 2024)
  else if (model.includes("small")) {
    req.body.model = "mistral.mistral-small-2402-v1:0";
  } else {
    throw new Error(`Can't map '${model}' to a supported AWS model ID; make sure you are requesting a Mistral model supported by Amazon Bedrock`);
  }
}
/**
 * Router for AWS-hosted Mistral models, exposing a single chat completions
 * endpoint.
 */
export const awsMistral = Router();

// Requests arrive and leave in the "mistral-ai" API format; after the
// preprocessor's transform step, the model name is mapped by
// `maybeReassignModel`.
const nativeMistralChatPreprocessor = createPreprocessorMiddleware(
  { inApi: "mistral-ai", outApi: "mistral-ai", service: "aws" },
  { afterTransform: [maybeReassignModel] }
);

// Handler order: rate limiter, then preprocessor, then the queued proxy.
awsMistral.post(
  "/v1/chat/completions",
  ipLimiter,
  nativeMistralChatPreprocessor,
  awsMistralProxy
);
+17 -4
View File
@@ -2,12 +2,14 @@
import { Request, Response, Router } from "express"; import { Request, Response, Router } from "express";
import { config } from "../config"; import { config } from "../config";
import { awsClaude } from "./aws-claude";
import { addV1 } from "./add-v1"; import { addV1 } from "./add-v1";
import { awsClaude } from "./aws-claude";
import { awsMistral } from "./aws-mistral";
import { AwsBedrockKey, keyPool } from "../shared/key-management";
const awsRouter = Router(); const awsRouter = Router();
awsRouter.use("/claude", addV1, awsClaude); awsRouter.use("/claude", addV1, awsClaude);
// awsRouter.use("/mistral", addV1, awsMistralRouter); awsRouter.use("/mistral", addV1, awsMistral);
awsRouter.get("/:vendor?/models", handleModelsRequest); awsRouter.get("/:vendor?/models", handleModelsRequest);
const MODELS_CACHE_TTL = 10000; const MODELS_CACHE_TTL = 10000;
@@ -19,6 +21,12 @@ function handleModelsRequest(req: Request, res: Response) {
return res.json(modelsCache); return res.json(modelsCache);
} }
const availableModelIds = new Set<string>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "aws") continue;
(key as AwsBedrockKey).modelIds.forEach((id) => availableModelIds.add(id));
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
const models = [ const models = [
"anthropic.claude-v2", "anthropic.claude-v2",
@@ -32,7 +40,9 @@ function handleModelsRequest(req: Request, res: Response) {
"mistral.mistral-large-2402-v1:0", "mistral.mistral-large-2402-v1:0",
"mistral.mistral-large-2407-v1:0", "mistral.mistral-large-2407-v1:0",
"mistral.mistral-small-2402-v1:0", "mistral.mistral-small-2402-v1:0",
].map((id) => { ]
.filter((id) => availableModelIds.has(id))
.map((id) => {
const vendor = id.match(/^(.*)\./)?.[1]; const vendor = id.match(/^(.*)\./)?.[1];
return { return {
id, id,
@@ -47,7 +57,10 @@ function handleModelsRequest(req: Request, res: Response) {
const requestedVendor = req.params.vendor; const requestedVendor = req.params.vendor;
const vendor = requestedVendor === "claude" ? "anthropic" : requestedVendor; const vendor = requestedVendor === "claude" ? "anthropic" : requestedVendor;
modelsCache = { object: "list", data: models.filter((m) => m.root === vendor) }; modelsCache = {
object: "list",
data: models.filter((m) => m.root === vendor),
};
modelsCacheTime = new Date().getTime(); modelsCacheTime = new Date().getTime();
return res.json(modelsCache); return res.json(modelsCache);
@@ -72,9 +72,6 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
.slice(0, 8)}`, .slice(0, 8)}`,
lastChecked: 0, lastChecked: 0,
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"], modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
// sonnetEnabled: true,
// haikuEnabled: false,
// sonnet35Enabled: false,
["aws-claudeTokens"]: 0, ["aws-claudeTokens"]: 0,
["aws-claude-opusTokens"]: 0, ["aws-claude-opusTokens"]: 0,
["aws-mistral-tinyTokens"]: 0, ["aws-mistral-tinyTokens"]: 0,