Files
OAI-Proxy/src/proxy/aws-mistral.ts
T

111 lines
3.2 KiB
TypeScript

import { Request } from "express";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { createQueueMiddleware } from "./queue";
import {
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeSignedRequest,
signAwsRequest,
} from "./middleware/request";
import { createProxyMiddleware } from "http-proxy-middleware";
import { logger } from "../logger";
import { handleProxyError } from "./middleware/common";
import { Router } from "express";
import { ipLimiter } from "./rate-limit";
import { detectMistralInputApi, transformMistralTextToMistralChat } from "./mistral-ai";
/**
 * Final handler for non-streaming AWS Mistral responses.
 *
 * Takes the already-parsed upstream body, converts it to the Mistral chat
 * shape when the client spoke `mistral-ai` but the request went out as
 * `mistral-text`, fills in the model name when it is absent from the
 * response, and sends the result to the client with status 200.
 *
 * @param _proxyRes Upstream response object (unused here).
 * @param req Incoming request; `inboundApi`, `outboundApi` and `body.model`
 *   are read from it.
 * @param res Client response that the JSON payload is written to.
 * @param body Upstream response body; must be a non-null object.
 * @throws Error if `body` is not a non-null object.
 */
const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  // `typeof null` is "object", so null has to be rejected explicitly;
  // otherwise the property reads below fail with an opaque TypeError.
  if (typeof body !== "object" || body === null) {
    throw new Error("Expected body to be an object");
  }

  let newBody = body;
  if (req.inboundApi === "mistral-ai" && req.outboundApi === "mistral-text") {
    newBody = transformMistralTextToMistralChat(body);
  }

  // AWS does not always confirm the model in the response, so we have to add it
  if (!newBody.model && req.body.model) {
    newBody.model = req.body.model;
  }

  // `proxy` is always copied from the original upstream body, even when the
  // payload itself was replaced by the transformed one.
  res.status(200).json({ ...newBody, proxy: body.proxy });
};
/** Runs `finalizeSignedRequest` on the outgoing proxied request. */
const awsMistralOnProxyReq = createOnProxyReqHandler({
  pipeline: [finalizeSignedRequest],
});

/** Passes upstream responses to the blocking response handler. */
const awsMistralOnProxyRes = createOnProxyResHandler([
  awsMistralBlockingResponseHandler,
]);

/**
 * Inner HTTP proxy. The static `target` is only a placeholder: the actual
 * upstream origin is computed per request by `router` from the signed request.
 */
const awsMistralProxyMiddleware = createProxyMiddleware({
  target: "bad-target-will-be-rewritten",
  router: (req) => {
    const signed = req.signedRequest;
    if (signed) {
      const { protocol, hostname } = signed;
      return `${protocol}//${hostname}`;
    }
    // Signing happens in `beforeProxy`; reaching this point unsigned is a bug.
    throw new Error("Must sign request before proxying");
  },
  changeOrigin: true,
  // Responses are handled by the `proxyRes` handler below rather than piped
  // through automatically.
  selfHandleResponse: true,
  logger,
  on: {
    proxyReq: awsMistralOnProxyReq,
    proxyRes: awsMistralOnProxyRes,
    error: handleProxyError,
  },
});

/**
 * Queue-wrapped AWS Mistral proxy: `signAwsRequest` is registered as the
 * `beforeProxy` step and the middleware above performs the proxying.
 */
const awsMistralProxy = createQueueMiddleware({
  beforeProxy: signAwsRequest,
  proxyMiddleware: awsMistralProxyMiddleware,
});
/**
 * Ordered lookup from a substring of the requested model name to the AWS
 * model ID it is rewritten to. The first matching entry wins, so a more
 * specific pattern must come before any pattern it contains ("8x7b" before
 * "7b", "large-2402" before "large").
 */
const AWS_MISTRAL_MODEL_ALIASES: ReadonlyArray<readonly [string, string]> = [
  // Mistral 8x7B Instruct
  ["8x7b", "mistral.mixtral-8x7b-instruct-v0:1"],
  // Mistral 7B Instruct
  ["7b", "mistral.mistral-7b-instruct-v0:2"],
  // Mistral Large (Feb 2024)
  ["large-2402", "mistral.mistral-large-2402-v1:0"],
  // Mistral Large 2 (July 2024)
  ["large", "mistral.mistral-large-2407-v1:0"],
  // Mistral Small (Feb 2024)
  ["small", "mistral.mistral-small-2402-v1:0"],
];

/**
 * Rewrites `req.body.model` in place to an AWS model ID.
 *
 * Names that already start with `mistral.` are treated as AWS model IDs and
 * left untouched. Any other name is matched against the alias table above;
 * the request is rejected when nothing matches.
 *
 * @param req Incoming request whose `body.model` is inspected and possibly
 *   overwritten.
 * @throws Error if `body.model` is not a string, or if it cannot be mapped to
 *   one of the known AWS model IDs.
 */
function maybeReassignModel(req: Request) {
  const model: unknown = req.body.model;
  if (typeof model !== "string") {
    throw new Error(
      "Can't map request to an AWS model ID; 'model' must be a string"
    );
  }

  // If it looks like an AWS model, use it as-is
  if (model.startsWith("mistral.")) {
    return;
  }

  const match = AWS_MISTRAL_MODEL_ALIASES.find(([pattern]) =>
    model.includes(pattern)
  );
  if (!match) {
    throw new Error(
      `Can't map '${model}' to a supported AWS model ID; make sure you are requesting a Mistral model supported by Amazon Bedrock`
    );
  }
  req.body.model = match[1];
}
/**
 * Preprocessor for the native Mistral chat endpoint on the AWS service.
 * Inbound and outbound API are both declared as `mistral-ai`;
 * `detectMistralInputApi` is registered as a before-transform hook and
 * `maybeReassignModel` as an after-transform hook.
 */
const nativeMistralChatPreprocessor = createPreprocessorMiddleware(
  {
    service: "aws",
    inApi: "mistral-ai",
    outApi: "mistral-ai",
  },
  {
    afterTransform: [
      // Rewrite the requested model name to an AWS model ID.
      maybeReassignModel,
    ],
    beforeTransform: [detectMistralInputApi],
  }
);
/**
 * Router for the AWS Mistral endpoints. It exposes a single route,
 * `POST /v1/chat/completions`, whose handlers run in the order listed.
 */
const awsMistralRouter = Router();

awsMistralRouter.post(
  "/v1/chat/completions",
  ipLimiter, // listed first, so it runs before any request processing
  nativeMistralChatPreprocessor,
  awsMistralProxy
);

/** Public name under which this router is exported. */
export { awsMistralRouter as awsMistral };