From 9e5a660ef5d7cc2fc5280205ec4c285ed77be492 Mon Sep 17 00:00:00 2001 From: nai-degen Date: Sun, 11 Aug 2024 12:39:01 -0500 Subject: [PATCH] refactors aws endpoint router to split claude/mistral --- src/proxy/add-v1.ts | 9 + src/proxy/aws-claude.ts | 246 +++++++++++++++++++++++++++ src/proxy/aws-mistral.ts | 0 src/proxy/aws.ts | 353 ++++----------------------------------- src/proxy/routes.ts | 44 ++--- 5 files changed, 315 insertions(+), 337 deletions(-) create mode 100644 src/proxy/add-v1.ts create mode 100644 src/proxy/aws-claude.ts create mode 100644 src/proxy/aws-mistral.ts diff --git a/src/proxy/add-v1.ts b/src/proxy/add-v1.ts new file mode 100644 index 0000000..d0d8620 --- /dev/null +++ b/src/proxy/add-v1.ts @@ -0,0 +1,9 @@ +import { NextFunction, Request, Response } from "express"; + +export function addV1(req: Request, res: Response, next: NextFunction) { + // Clients don't consistently use the /v1 prefix so we'll add it for them. + if (!req.path.startsWith("/v1/") && !req.path.startsWith("/v1beta/")) { + req.url = `/v1${req.url}`; + } + next(); +} diff --git a/src/proxy/aws-claude.ts b/src/proxy/aws-claude.ts new file mode 100644 index 0000000..a3ab783 --- /dev/null +++ b/src/proxy/aws-claude.ts @@ -0,0 +1,246 @@ +import { Request, RequestHandler, Router } from "express"; +import { createProxyMiddleware } from "http-proxy-middleware"; +import { v4 } from "uuid"; +import { config } from "../config"; +import { logger } from "../logger"; +import { createQueueMiddleware } from "./queue"; +import { ipLimiter } from "./rate-limit"; +import { handleProxyError } from "./middleware/common"; +import { + createPreprocessorMiddleware, + signAwsRequest, + finalizeSignedRequest, + createOnProxyReqHandler, +} from "./middleware/request"; +import { + ProxyResHandlerWithBody, + createOnProxyResHandler, +} from "./middleware/response"; +import { transformAnthropicChatResponseToAnthropicText, transformAnthropicChatResponseToOpenAI } from "./anthropic"; + +/** Only used for non-streaming requests. */ +const awsResponseHandler: ProxyResHandlerWithBody = async ( + _proxyRes, + req, + res, + body +) => { + if (typeof body !== "object") { + throw new Error("Expected body to be an object"); + } + + let newBody = body; + switch (`${req.inboundApi}<-${req.outboundApi}`) { + case "openai<-anthropic-text": + req.log.info("Transforming Anthropic Text back to OpenAI format"); + newBody = transformAwsTextResponseToOpenAI(body, req); + break; + case "openai<-anthropic-chat": + req.log.info("Transforming AWS Anthropic Chat back to OpenAI format"); + newBody = transformAnthropicChatResponseToOpenAI(body); + break; + case "anthropic-text<-anthropic-chat": + req.log.info("Transforming AWS Anthropic Chat back to Text format"); + newBody = transformAnthropicChatResponseToAnthropicText(body); + break; + } + + // AWS does not always confirm the model in the response, so we have to add it + if (!newBody.model && req.body.model) { + newBody.model = req.body.model; + } + + res.status(200).json({ ...newBody, proxy: body.proxy }); +}; + +/** + * Transforms a model response from the Anthropic API to match those from the + * OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This + * is only used for non-streaming requests as streaming requests are handled + * on-the-fly. + */ +function transformAwsTextResponseToOpenAI( + awsBody: Record, + req: Request +): Record { + const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 
0); + return { + id: "aws-" + v4(), + object: "chat.completion", + created: Date.now(), + model: req.body.model, + usage: { + prompt_tokens: req.promptTokens, + completion_tokens: req.outputTokens, + total_tokens: totalTokens, + }, + choices: [ + { + message: { + role: "assistant", + content: awsBody.completion?.trim(), + }, + finish_reason: awsBody.stop_reason, + index: 0, + }, + ], + }; +} + +const awsProxy = createQueueMiddleware({ + beforeProxy: signAwsRequest, + proxyMiddleware: createProxyMiddleware({ + target: "bad-target-will-be-rewritten", + router: ({ signedRequest }) => { + if (!signedRequest) throw new Error("Must sign request before proxying"); + return `${signedRequest.protocol}//${signedRequest.hostname}`; + }, + changeOrigin: true, + selfHandleResponse: true, + logger, + on: { + proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }), + proxyRes: createOnProxyResHandler([awsResponseHandler]), + error: handleProxyError, + }, + }), +}); + +const nativeTextPreprocessor = createPreprocessorMiddleware( + { inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" }, + { afterTransform: [maybeReassignModel] } +); + +const textToChatPreprocessor = createPreprocessorMiddleware( + { inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" }, + { afterTransform: [maybeReassignModel] } +); + +/** + * Routes text completion prompts to aws anthropic-chat if they need translation + * (claude-3 based models do not support the old text completion endpoint). + */ +const preprocessAwsTextRequest: RequestHandler = (req, res, next) => { + if (req.body.model?.includes("claude-3")) { + textToChatPreprocessor(req, res, next); + } else { + nativeTextPreprocessor(req, res, next); + } +}; + +const oaiToAwsTextPreprocessor = createPreprocessorMiddleware( + { inApi: "openai", outApi: "anthropic-text", service: "aws" }, + { afterTransform: [maybeReassignModel] } +); + +const oaiToAwsChatPreprocessor = createPreprocessorMiddleware( + { inApi: "openai", outApi: "anthropic-chat", service: "aws" }, + { afterTransform: [maybeReassignModel] } +); + +/** + * Routes an OpenAI prompt to either the legacy Claude text completion endpoint + * or the new Claude chat completion endpoint, based on the requested model. + */ +const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => { + if (req.body.model?.includes("claude-3")) { + oaiToAwsChatPreprocessor(req, res, next); + } else { + oaiToAwsTextPreprocessor(req, res, next); + } +}; + +const awsClaudeRouter = Router(); +// Native(ish) Anthropic text completion endpoint. +awsClaudeRouter.post("/v1/complete", ipLimiter, preprocessAwsTextRequest, awsProxy); +// Native Anthropic chat completion endpoint. +awsClaudeRouter.post( + "/v1/messages", + ipLimiter, + createPreprocessorMiddleware( + { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" }, + { afterTransform: [maybeReassignModel] } + ), + awsProxy +); + +// OpenAI-to-AWS Anthropic compatibility endpoint. +awsClaudeRouter.post( + "/v1/chat/completions", + ipLimiter, + preprocessOpenAICompatRequest, + awsProxy +); + +/** + * Tries to deal with: + * - frontends sending AWS model names even when they want to use the OpenAI- + * compatible endpoint + * - frontends sending Anthropic model names that AWS doesn't recognize + * - frontends sending OpenAI model names because they expect the proxy to + * translate them + * + * If client sends AWS model ID it will be used verbatim. 
Otherwise, various + * strategies are used to try to map a non-AWS model name to AWS model ID. + */ +function maybeReassignModel(req: Request) { + const model = req.body.model; + + // If it looks like an AWS model, use it as-is + if (model.includes("anthropic.claude")) { + return; + } + + // Anthropic model names can look like: + // - claude-v1 + // - claude-2.1 + // - claude-3-5-sonnet-20240620-v1:0 + const pattern = + /^(claude-)?(instant-)?(v)?(\d+)([.-](\d))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i; + const match = model.match(pattern); + + // If there's no match, fallback to Claude v2 as it is most likely to be + // available on AWS. + if (!match) { + req.body.model = `anthropic.claude-v2:1`; + return; + } + + const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match; + + if (instant) { + req.body.model = "anthropic.claude-instant-v1"; + return; + } + + const ver = minor ? `${major}.${minor}` : major; + switch (ver) { + case "1": + case "1.0": + req.body.model = "anthropic.claude-v1"; + return; + case "2": + case "2.0": + req.body.model = "anthropic.claude-v2"; + return; + case "3": + case "3.0": + if (name.includes("opus")) { + req.body.model = "anthropic.claude-3-opus-20240229-v1:0"; + } else if (name.includes("haiku")) { + req.body.model = "anthropic.claude-3-haiku-20240307-v1:0"; + } else { + req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0"; + } + return; + case "3.5": + req.body.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"; + return; + } + + // Fallback to Claude 2.1 + req.body.model = `anthropic.claude-v2:1`; + return; +} + +export const awsClaude = awsClaudeRouter; diff --git a/src/proxy/aws-mistral.ts b/src/proxy/aws-mistral.ts new file mode 100644 index 0000000..e69de29 diff --git a/src/proxy/aws.ts b/src/proxy/aws.ts index c02ff0e..9e26603 100644 --- a/src/proxy/aws.ts +++ b/src/proxy/aws.ts @@ -1,337 +1,56 @@ -import { Request, RequestHandler, Response, Router } from "express"; -import { createProxyMiddleware } from "http-proxy-middleware"; -import { v4 } from "uuid"; +/* Shared code between AWS Claude and AWS Mistral endpoints. 
*/ + +import { Request, Response, Router } from "express"; import { config } from "../config"; -import { logger } from "../logger"; -import { createQueueMiddleware } from "./queue"; -import { ipLimiter } from "./rate-limit"; -import { handleProxyError } from "./middleware/common"; -import { - createPreprocessorMiddleware, - signAwsRequest, - finalizeSignedRequest, - createOnProxyReqHandler, -} from "./middleware/request"; -import { - ProxyResHandlerWithBody, - createOnProxyResHandler, -} from "./middleware/response"; -import { transformAnthropicChatResponseToAnthropicText, transformAnthropicChatResponseToOpenAI } from "./anthropic"; -import { sendErrorToClient } from "./middleware/response/error-generator"; +import { awsClaude } from "./aws-claude"; +import { addV1 } from "./add-v1"; -const LATEST_AWS_V2_MINOR_VERSION = "1"; +const awsRouter = Router(); +awsRouter.use("/claude", addV1, awsClaude); +// awsRouter.use("/mistral", addV1, awsMistralRouter); +awsRouter.get("/:vendor?/models", handleModelsRequest); +const MODELS_CACHE_TTL = 10000; let modelsCache: any = null; let modelsCacheTime = 0; - -const getModelsResponse = () => { - if (new Date().getTime() - modelsCacheTime < 1000 * 60) { - return modelsCache; +function handleModelsRequest(req: Request, res: Response) { + if (!config.awsCredentials) return { object: "list", data: [] }; + if (new Date().getTime() - modelsCacheTime < MODELS_CACHE_TTL) { + return res.json(modelsCache); } - if (!config.awsCredentials) return { object: "list", data: [] }; - // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html - const variants = [ + const models = [ "anthropic.claude-v2", "anthropic.claude-v2:1", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-3-5-sonnet-20240620-v1:0", "anthropic.claude-3-opus-20240229-v1:0", - ]; + "mistral.mistral-7b-instruct-v0:2", + "mistral.mixtral-8x7b-instruct-v0:1", + "mistral.mistral-large-2402-v1:0", + "mistral.mistral-large-2407-v1:0", + "mistral.mistral-small-2402-v1:0", + ].map((id) => { + const vendor = id.match(/^(.*)\./)?.[1]; + return { + id, + object: "model", + created: new Date().getTime(), + owned_by: vendor, + permission: [], + root: vendor, + parent: null, + }; + }); - const models = variants.map((id) => ({ - id, - object: "model", - created: new Date().getTime(), - owned_by: "anthropic", - permission: [], - root: "claude", - parent: null, - })); - - modelsCache = { object: "list", data: models }; + const requestedVendor = req.params.vendor; + const vendor = requestedVendor === "claude" ? "anthropic" : requestedVendor; + modelsCache = { object: "list", data: models.filter((m) => m.root === vendor) }; modelsCacheTime = new Date().getTime(); - return modelsCache; -}; - -const handleModelRequest: RequestHandler = (_req, res) => { - res.status(200).json(getModelsResponse()); -}; - -/** Only used for non-streaming requests. 
*/ -const awsResponseHandler: ProxyResHandlerWithBody = async ( - _proxyRes, - req, - res, - body -) => { - if (typeof body !== "object") { - throw new Error("Expected body to be an object"); - } - - let newBody = body; - switch (`${req.inboundApi}<-${req.outboundApi}`) { - case "openai<-anthropic-text": - req.log.info("Transforming Anthropic Text back to OpenAI format"); - newBody = transformAwsTextResponseToOpenAI(body, req); - break; - case "openai<-anthropic-chat": - req.log.info("Transforming AWS Anthropic Chat back to OpenAI format"); - newBody = transformAnthropicChatResponseToOpenAI(body); - break; - case "anthropic-text<-anthropic-chat": - req.log.info("Transforming AWS Anthropic Chat back to Text format"); - newBody = transformAnthropicChatResponseToAnthropicText(body); - break; - } - - // AWS does not always confirm the model in the response, so we have to add it - if (!newBody.model && req.body.model) { - newBody.model = req.body.model; - } - - res.status(200).json({ ...newBody, proxy: body.proxy }); -}; - -/** - * Transforms a model response from the Anthropic API to match those from the - * OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This - * is only used for non-streaming requests as streaming requests are handled - * on-the-fly. - */ -function transformAwsTextResponseToOpenAI( - awsBody: Record, - req: Request -): Record { - const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0); - return { - id: "aws-" + v4(), - object: "chat.completion", - created: Date.now(), - model: req.body.model, - usage: { - prompt_tokens: req.promptTokens, - completion_tokens: req.outputTokens, - total_tokens: totalTokens, - }, - choices: [ - { - message: { - role: "assistant", - content: awsBody.completion?.trim(), - }, - finish_reason: awsBody.stop_reason, - index: 0, - }, - ], - }; -} - -const awsProxy = createQueueMiddleware({ - beforeProxy: signAwsRequest, - proxyMiddleware: createProxyMiddleware({ - target: "bad-target-will-be-rewritten", - router: ({ signedRequest }) => { - if (!signedRequest) throw new Error("Must sign request before proxying"); - return `${signedRequest.protocol}//${signedRequest.hostname}`; - }, - changeOrigin: true, - selfHandleResponse: true, - logger, - on: { - proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }), - proxyRes: createOnProxyResHandler([awsResponseHandler]), - error: handleProxyError, - }, - }), -}); - -const nativeTextPreprocessor = createPreprocessorMiddleware( - { inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" }, - { afterTransform: [maybeReassignModel] } -); - -const textToChatPreprocessor = createPreprocessorMiddleware( - { inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" }, - { afterTransform: [maybeReassignModel] } -); - -/** - * Routes text completion prompts to aws anthropic-chat if they need translation - * (claude-3 based models do not support the old text completion endpoint). 
- */ -const preprocessAwsTextRequest: RequestHandler = (req, res, next) => { - if (req.body.model?.includes("claude-3")) { - textToChatPreprocessor(req, res, next); - } else { - nativeTextPreprocessor(req, res, next); - } -}; - -const oaiToAwsTextPreprocessor = createPreprocessorMiddleware( - { inApi: "openai", outApi: "anthropic-text", service: "aws" }, - { afterTransform: [maybeReassignModel] } -); - -const oaiToAwsChatPreprocessor = createPreprocessorMiddleware( - { inApi: "openai", outApi: "anthropic-chat", service: "aws" }, - { afterTransform: [maybeReassignModel] } -); - -/** - * Routes an OpenAI prompt to either the legacy Claude text completion endpoint - * or the new Claude chat completion endpoint, based on the requested model. - */ -const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => { - if (req.body.model?.includes("claude-3")) { - oaiToAwsChatPreprocessor(req, res, next); - } else { - oaiToAwsTextPreprocessor(req, res, next); - } -}; - -const awsRouter = Router(); -awsRouter.get("/v1/models", handleModelRequest); -// Native(ish) Anthropic text completion endpoint. -awsRouter.post("/v1/complete", ipLimiter, preprocessAwsTextRequest, awsProxy); -// Native Anthropic chat completion endpoint. -awsRouter.post( - "/v1/messages", - ipLimiter, - createPreprocessorMiddleware( - { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" }, - { afterTransform: [maybeReassignModel] } - ), - awsProxy -); -// Temporary force-Claude3 endpoint -awsRouter.post( - "/v1/sonnet/:action(complete|messages)", - ipLimiter, - handleCompatibilityRequest, - createPreprocessorMiddleware({ - inApi: "anthropic-text", - outApi: "anthropic-chat", - service: "aws", - }), - awsProxy -); - -// OpenAI-to-AWS Anthropic compatibility endpoint. -awsRouter.post( - "/v1/chat/completions", - ipLimiter, - preprocessOpenAICompatRequest, - awsProxy -); - -/** - * Tries to deal with: - * - frontends sending AWS model names even when they want to use the OpenAI- - * compatible endpoint - * - frontends sending Anthropic model names that AWS doesn't recognize - * - frontends sending OpenAI model names because they expect the proxy to - * translate them - * - * If client sends AWS model ID it will be used verbatim. Otherwise, various - * strategies are used to try to map a non-AWS model name to AWS model ID. - */ -function maybeReassignModel(req: Request) { - const model = req.body.model; - - // If it looks like an AWS model, use it as-is - if (model.includes("anthropic.claude")) { - return; - } - - // Anthropic model names can look like: - // - claude-v1 - // - claude-2.1 - // - claude-3-5-sonnet-20240620-v1:0 - const pattern = - /^(claude-)?(instant-)?(v)?(\d+)([.-](\d{1}))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i; - const match = model.match(pattern); - - // If there's no match, fallback to Claude v2 as it is most likely to be - // available on AWS. - if (!match) { - req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`; - return; - } - - const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match; - - if (instant) { - req.body.model = "anthropic.claude-instant-v1"; - return; - } - - const ver = minor ? 
`${major}.${minor}` : major; - switch (ver) { - case "1": - case "1.0": - req.body.model = "anthropic.claude-v1"; - return; - case "2": - case "2.0": - req.body.model = "anthropic.claude-v2"; - return; - case "3": - case "3.0": - if (name.includes("opus")) { - req.body.model = "anthropic.claude-3-opus-20240229-v1:0"; - } else if (name.includes("haiku")) { - req.body.model = "anthropic.claude-3-haiku-20240307-v1:0"; - } else { - req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0"; - } - return; - case "3.5": - req.body.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"; - return; - } - - // Fallback to Claude 2.1 - req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`; - return; -} - -export function handleCompatibilityRequest( - req: Request, - res: Response, - next: any -) { - const action = req.params.action; - const alreadyInChatFormat = Boolean(req.body.messages); - const compatModel = "anthropic.claude-3-5-sonnet-20240620-v1:0"; - req.log.info( - { inputModel: req.body.model, compatModel, alreadyInChatFormat }, - "Handling AWS compatibility request" - ); - - if (action === "messages" || alreadyInChatFormat) { - return sendErrorToClient({ - req, - res, - options: { - title: "Unnecessary usage of compatibility endpoint", - message: `Your client seems to already support the new Claude API format. This endpoint is intended for clients that do not yet support the new format.\nUse the normal \`/aws/claude\` proxy endpoint instead.`, - format: "unknown", - statusCode: 400, - reqId: req.id, - obj: { - requested_endpoint: "/aws/claude/sonnet", - correct_endpoint: "/aws/claude", - }, - }, - }); - } - - req.body.model = compatModel; - next(); + return res.json(modelsCache); } export const aws = awsRouter; diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index 9932ef2..069f0ab 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -1,44 +1,55 @@ -import express, { Request, Response, NextFunction } from "express"; -import { gatekeeper } from "./gatekeeper"; -import { checkRisuToken } from "./check-risu-token"; -import { openai } from "./openai"; -import { openaiImage } from "./openai-image"; +import express from "express"; +import { addV1 } from "./add-v1"; import { anthropic } from "./anthropic"; +import { aws } from "./aws"; +import { azure } from "./azure"; +import { checkRisuToken } from "./check-risu-token"; +import { gatekeeper } from "./gatekeeper"; +import { gcp } from "./gcp"; import { googleAI } from "./google-ai"; import { mistralAI } from "./mistral-ai"; -import { aws } from "./aws"; -import { gcp } from "./gcp"; -import { azure } from "./azure"; +import { openai } from "./openai"; +import { openaiImage } from "./openai-image"; import { sendErrorToClient } from "./middleware/response/error-generator"; const proxyRouter = express.Router(); + +// Remove `expect: 100-continue` header from requests due to incompatibility +// with node-http-proxy. proxyRouter.use((req, _res, next) => { if (req.headers.expect) { - // node-http-proxy does not like it when clients send `expect: 100-continue` - // and will stall. none of the upstream APIs use this header anyway. delete req.headers.expect; } next(); }); + +// Apply body parsers. proxyRouter.use( express.json({ limit: "100mb" }), express.urlencoded({ extended: true, limit: "100mb" }) ); + +// Apply auth/rate limits. proxyRouter.use(gatekeeper); proxyRouter.use(checkRisuToken); + +// Initialize request queue metadata. 
proxyRouter.use((req, _res, next) => { req.startTime = Date.now(); req.retryCount = 0; next(); }); + +// Proxy endpoints. proxyRouter.use("/openai", addV1, openai); proxyRouter.use("/openai-image", addV1, openaiImage); proxyRouter.use("/anthropic", addV1, anthropic); proxyRouter.use("/google-ai", addV1, googleAI); proxyRouter.use("/mistral-ai", addV1, mistralAI); -proxyRouter.use("/aws/claude", addV1, aws); +proxyRouter.use("/aws", aws); proxyRouter.use("/gcp/claude", addV1, gcp); proxyRouter.use("/azure/openai", addV1, azure); + // Redirect browser requests to the homepage. proxyRouter.get("*", (req, res, next) => { const isBrowser = req.headers["user-agent"]?.includes("Mozilla"); @@ -48,7 +59,8 @@ proxyRouter.get("*", (req, res, next) => { next(); } }); -// Handle 404s. + +// Send a fake client error if user specifies an invalid proxy endpoint. proxyRouter.use((req, res) => { sendErrorToClient({ req, @@ -69,11 +81,3 @@ proxyRouter.use((req, res) => { }); export { proxyRouter as proxyRouter }; - -function addV1(req: Request, res: Response, next: NextFunction) { - // Clients don't consistently use the /v1 prefix so we'll add it for them. - if (!req.path.startsWith("/v1/") && !req.path.startsWith("/v1beta/")) { - req.url = `/v1${req.url}`; - } - next(); -}
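
A quick sketch of the model-name resolution performed by maybeReassignModel in the new src/proxy/aws-claude.ts, for anyone reviewing the routing split. The snippet is illustrative only and not part of this patch: it assumes maybeReassignModel were exported from aws-claude.ts (it is module-private in the patch), and the mapModel wrapper and sample inputs are hypothetical.

    // Hypothetical helper, only to demonstrate the mapping implemented above.
    import { Request } from "express";
    import { maybeReassignModel } from "./aws-claude"; // assumed export; private in the patch

    function mapModel(model: string): string {
      // Minimal stand-in for an Express request; only req.body.model is read and rewritten.
      const req = { body: { model } } as unknown as Request;
      maybeReassignModel(req);
      return req.body.model;
    }

    // Resolutions derived from the regex and switch in maybeReassignModel:
    // mapModel("anthropic.claude-v2:1")      -> "anthropic.claude-v2:1" (AWS model IDs pass through verbatim)
    // mapModel("claude-instant-v1")          -> "anthropic.claude-instant-v1"
    // mapModel("claude-v1")                  -> "anthropic.claude-v1"
    // mapModel("claude-2")                   -> "anthropic.claude-v2"
    // mapModel("claude-2.1")                 -> "anthropic.claude-v2:1" (no "2.1" case; hits the final fallback)
    // mapModel("claude-3-opus-20240229")     -> "anthropic.claude-3-opus-20240229-v1:0"
    // mapModel("claude-3-haiku-20240307")    -> "anthropic.claude-3-haiku-20240307-v1:0"
    // mapModel("claude-3-5-sonnet-20240620") -> "anthropic.claude-3-5-sonnet-20240620-v1:0"
    // mapModel("gpt-4")                      -> "anthropic.claude-v2:1" (no regex match; default fallback)

These resolutions match the logic removed from src/proxy/aws.ts; the only differences are that the old anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION} fallback is now written out as anthropic.claude-v2:1 and the minor-version group is simplified from (\d{1}) to the equivalent (\d).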