18 Commits

Author SHA1 Message Date
user 9e6fd7c24c Implement tools (function calling) for Claude 2024-09-08 00:04:03 +00:00
nai-degen ac92a19946 improves reliability of inference profile detection for AWS keychecker 2024-09-07 17:36:29 -05:00
khanon 96fe974ad0 Use AWS Inference Profiles for higher rate limits (khanon/oai-reverse-proxy!78) 2024-09-01 22:55:07 +00:00
nai-degen 578615fbd2 fixes typo in new Claude system prompt schema 2024-08-30 10:23:57 -05:00
nai-degen 5dc4050e52 disable periodic GCP key rechecks to workaround keychecker bug 2024-08-29 15:25:37 -05:00
nai-degen cf615ee62c applies prettier to GCP checker 2024-08-29 15:15:56 -05:00
nai-degen ee61f9be2b removes unnecessary log from last commit 2024-08-27 23:58:32 -05:00
nai-degen 0c448cb59d fixes azure dalle using wrong rate limit and out-of-spec Retry-After header 2024-08-27 23:53:28 -05:00
nai-degen 51a9ccceb2 supports alternate claude system prompt format 2024-08-27 23:27:20 -05:00
nai-degen ce490efd7d minor adjustments to HMAC signing 2024-08-22 19:54:02 -05:00
nai-degen 5000e59a61 fix for google makersuite prompt validation/transformation 2024-08-22 14:19:48 -05:00
nai-degen d54acad6ad adds support for sonnet 8192 output tokens on anthropic api 2024-08-15 11:55:13 -05:00
nai-degen 5e1fffe07d adds chatgpt-4o-latest 2024-08-15 11:54:42 -05:00
nai-degen f7fd5f00f2 fixes response_format schema for mistral la plateforme 2024-08-14 14:41:47 -05:00
nai-degen 6d323f6ea1 do not transform mistral chat prompts to text when using la plateforme 2024-08-14 12:26:27 -05:00
nai-degen 2959ed3f7f fixes aws keychecker not detecting claude 2.1 2024-08-14 10:49:02 -05:00
nai-degen b58e7cb830 always applies Mistral prompt fixes on messages input 2024-08-14 10:48:55 -05:00
khanon f531272b00 Refactor AWS service code and add AWS Mistral support (khanon/oai-reverse-proxy!75) 2024-08-14 04:40:41 +00:00
69 changed files with 2115 additions and 1254 deletions
+98 -11
@@ -11,14 +11,14 @@
"dependencies": {
"@anthropic-ai/tokenizer": "^0.0.4",
"@aws-crypto/sha256-js": "^5.2.0",
"@huggingface/jinja": "^0.3.0",
"@node-rs/argon2": "^1.8.3",
"@smithy/eventstream-codec": "^2.1.3",
"@smithy/eventstream-serde-node": "^2.1.3",
"@smithy/protocol-http": "^3.2.1",
"@smithy/signature-v4": "^2.1.3",
"@smithy/types": "^2.10.1",
"@smithy/util-utf8": "^2.1.1",
"axios": "^1.3.5",
"axios": "^1.7.4",
"better-sqlite3": "^10.0.0",
"check-disk-space": "^3.4.0",
"cookie-parser": "^1.4.6",
@@ -51,6 +51,7 @@
"zod-error": "^1.5.0"
},
"devDependencies": {
"@smithy/types": "^3.3.0",
"@types/better-sqlite3": "^7.6.10",
"@types/cookie-parser": "^1.4.3",
"@types/cors": "^2.8.13",
@@ -151,6 +152,17 @@
"node": ">=14.0.0"
}
},
"node_modules/@aws-sdk/types/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@aws-sdk/util-utf8-browser": {
"version": "3.259.0",
"resolved": "https://registry.npmjs.org/@aws-sdk/util-utf8-browser/-/util-utf8-browser-3.259.0.tgz",
@@ -866,6 +878,14 @@
"node": ">=6"
}
},
"node_modules/@huggingface/jinja": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.3.0.tgz",
"integrity": "sha512-GLJzso0M07ZncFkrJMIXVU4os6GFbPocD4g8fMQPMGJubf48FtGOsUORH2rtFdXPIPelz8SLBMn8ZRmOTwXm9Q==",
"engines": {
"node": ">=18"
}
},
"node_modules/@isaacs/cliui": {
"version": "8.0.2",
"resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz",
@@ -1319,6 +1339,17 @@
"tslib": "^2.5.0"
}
},
"node_modules/@smithy/eventstream-codec/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-node": {
"version": "2.1.3",
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-node/-/eventstream-serde-node-2.1.3.tgz",
@@ -1332,6 +1363,17 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-node/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-universal": {
"version": "2.1.3",
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-universal/-/eventstream-serde-universal-2.1.3.tgz",
@@ -1345,6 +1387,17 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-universal/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/is-array-buffer": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-2.1.1.tgz",
@@ -1368,6 +1421,17 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/protocol-http/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/signature-v4": {
"version": "2.1.3",
"resolved": "https://registry.npmjs.org/@smithy/signature-v4/-/signature-v4-2.1.3.tgz",
@@ -1386,17 +1450,29 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/types": {
"version": "2.10.1",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.10.1.tgz",
"integrity": "sha512-hjQO+4ru4cQ58FluQvKKiyMsFg0A6iRpGm2kqdH8fniyNd2WyanoOsYJfMX/IFLuLxEoW6gnRkNZy1y6fUUhtA==",
"node_modules/@smithy/signature-v4/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.5.0"
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/types": {
"version": "3.3.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-3.3.0.tgz",
"integrity": "sha512-IxvBBCTFDHbVoK7zIxqA1ZOdc4QfM5HM7rGleCuHi7L1wnKv5Pn69xXJQ9hgxH60ZVygH9/JG0jRgtUncE3QUA==",
"dev": true,
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=16.0.0"
}
},
"node_modules/@smithy/util-buffer-from": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-2.1.1.tgz",
@@ -1432,6 +1508,17 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/util-middleware/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/util-uri-escape": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/@smithy/util-uri-escape/-/util-uri-escape-2.1.1.tgz",
@@ -1887,11 +1974,11 @@
}
},
"node_modules/axios": {
"version": "1.6.1",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz",
"integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==",
"version": "1.7.4",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.7.4.tgz",
"integrity": "sha512-DukmaFRnY6AzAALSH4J2M3k6PkaC+MfaAGdEERRWcC9q3/TWQwLpHR8ZRLKTdQ3aBDL64EdluRDjJqKw+BPZEw==",
"dependencies": {
"follow-redirects": "^1.15.0",
"follow-redirects": "^1.15.6",
"form-data": "^4.0.0",
"proxy-from-env": "^1.1.0"
}
+3 -2
@@ -20,14 +20,14 @@
"dependencies": {
"@anthropic-ai/tokenizer": "^0.0.4",
"@aws-crypto/sha256-js": "^5.2.0",
"@huggingface/jinja": "^0.3.0",
"@node-rs/argon2": "^1.8.3",
"@smithy/eventstream-codec": "^2.1.3",
"@smithy/eventstream-serde-node": "^2.1.3",
"@smithy/protocol-http": "^3.2.1",
"@smithy/signature-v4": "^2.1.3",
"@smithy/types": "^2.10.1",
"@smithy/util-utf8": "^2.1.1",
"axios": "^1.3.5",
"axios": "^1.7.4",
"better-sqlite3": "^10.0.0",
"check-disk-space": "^3.4.0",
"cookie-parser": "^1.4.6",
@@ -60,6 +60,7 @@
"zod-error": "^1.5.0"
},
"devDependencies": {
"@smithy/types": "^3.3.0",
"@types/better-sqlite3": "^7.6.10",
"@types/cookie-parser": "^1.4.3",
"@types/cors": "^2.8.13",
+118
@@ -0,0 +1,118 @@
// uses the aws sdk to sign a request, then uses axios to send it to the bedrock REST API manually
import axios from "axios";
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
const AWS_ACCESS_KEY_ID = process.env.AWS_ACCESS_KEY_ID!;
const AWS_SECRET_ACCESS_KEY = process.env.AWS_SECRET_ACCESS_KEY!;
// Copied from amazon bedrock docs
// List models
// ListFoundationModels
// Service: Amazon Bedrock
// List of Bedrock foundation models that you can use. For more information, see Foundation models in the
// Bedrock User Guide.
// Request Syntax
// GET /foundation-models?
// byCustomizationType=byCustomizationType&byInferenceType=byInferenceType&byOutputModality=byOutputModality&byProvider=byProvider
// HTTP/1.1
// URI Request Parameters
// The request uses the following URI parameters.
// byCustomizationType (p. 38)
// List by customization type.
// Valid Values: FINE_TUNING
// byInferenceType (p. 38)
// List by inference type.
// Valid Values: ON_DEMAND | PROVISIONED
// byOutputModality (p. 38)
// List by output modality type.
// Valid Values: TEXT | IMAGE | EMBEDDING
// byProvider (p. 38)
// A Bedrock model provider.
// Pattern: ^[a-z0-9-]{1,63}$
// Request Body
// The request does not have a request body
// Run inference on a text model
// Send an invoke request to run inference on a Titan Text G1 - Express model. We set the accept
// parameter to accept any content type in the response.
// POST https://bedrock.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke
// -H accept: */*
// -H content-type: application/json
// Payload
// {"inputText": "Hello world"}
// Example response
// Response for the above request.
// -H content-type: application/json
// Payload
// <the model response>
const AMZ_REGION = "us-east-1";
const AMZ_HOST = "invoke-bedrock.us-east-1.amazonaws.com";
async function listModels() {
const httpRequest = new HttpRequest({
method: "GET",
protocol: "https:",
hostname: AMZ_HOST,
path: "/foundation-models",
headers: { ["Host"]: AMZ_HOST },
});
const signedRequest = await signRequest(httpRequest);
const response = await axios.get(
`https://${signedRequest.hostname}${signedRequest.path}`,
{ headers: signedRequest.headers }
);
console.log(response.data);
}
async function invokeModel() {
const model = "anthropic.claude-v1";
const httpRequest = new HttpRequest({
method: "POST",
protocol: "https:",
hostname: AMZ_HOST,
path: `/model/${model}/invoke`,
headers: {
["Host"]: AMZ_HOST,
["accept"]: "*/*",
["content-type"]: "application/json",
},
body: JSON.stringify({
temperature: 0.5,
prompt: "\n\nHuman:Hello world\n\nAssistant:",
max_tokens_to_sample: 10,
}),
});
console.log("httpRequest", httpRequest);
const signedRequest = await signRequest(httpRequest);
const response = await axios.post(
`https://${signedRequest.hostname}${signedRequest.path}`,
signedRequest.body,
{ headers: signedRequest.headers }
);
console.log(response.status);
console.log(response.headers);
console.log(response.data);
console.log("full url", response.request.res.responseUrl);
}
async function signRequest(request: HttpRequest) {
const signer = new SignatureV4({
sha256: Sha256,
credentials: {
accessKeyId: AWS_ACCESS_KEY_ID,
secretAccessKey: AWS_SECRET_ACCESS_KEY,
},
region: AMZ_REGION,
service: "bedrock",
});
return await signer.sign(request, { signingDate: new Date() });
}
// listModels();
// invokeModel();
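To exercise this scratch file, one of the commented-out entry points has to be enabled; a minimal sketch, assuming AWS credentials are exported in the environment and the file is run with ts-node (the error handling is illustrative, not part of the original script):

listModels().catch((err) => {
  // SigV4 or permission failures surface here as axios errors with a response body.
  console.error(err.response?.status, err.response?.data ?? err.message);
});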
+3 -2
@@ -17,7 +17,7 @@ import {
} from "../../shared/users/schema";
import { getLastNImages } from "../../shared/file-storage/image-history";
import { blacklists, parseCidrs, whitelists } from "../../shared/cidr";
import { invalidatePowHmacKey } from "../../user/web/pow-captcha";
import { invalidatePowChallenges } from "../../user/web/pow-captcha";
const router = Router();
@@ -323,7 +323,7 @@ router.post("/maintenance", (req, res) => {
user.disabledReason = "Admin forced expiration.";
userStore.upsertUser(user);
});
invalidatePowHmacKey();
invalidatePowChallenges();
flash.type = "success";
flash.message = `${temps.length} temporary users marked for expiration.`;
break;
@@ -348,6 +348,7 @@ router.post("/maintenance", (req, res) => {
throw new HttpError(400, "Invalid difficulty" + selected);
}
config.powDifficultyLevel = selected;
invalidatePowChallenges();
break;
}
case "generateTempIpReport": {
+13 -30
@@ -415,44 +415,23 @@ export const config: Config = {
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
textModelRateLimit: getEnvWithDefault("TEXT_MODEL_RATE_LIMIT", 4),
imageModelRateLimit: getEnvWithDefault("IMAGE_MODEL_RATE_LIMIT", 4),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 16384),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 32768),
maxContextTokensAnthropic: getEnvWithDefault(
"MAX_CONTEXT_TOKENS_ANTHROPIC",
0
32768
),
maxOutputTokensOpenAI: getEnvWithDefault(
["MAX_OUTPUT_TOKENS_OPENAI", "MAX_OUTPUT_TOKENS"],
400
1024
),
maxOutputTokensAnthropic: getEnvWithDefault(
["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"],
400
1024
),
allowedModelFamilies: getEnvWithDefault(
"ALLOWED_MODEL_FAMILIES",
getDefaultModelFamilies()
),
allowedModelFamilies: getEnvWithDefault("ALLOWED_MODEL_FAMILIES", [
"turbo",
"gpt4",
"gpt4-32k",
"gpt4-turbo",
"gpt4o",
"claude",
"claude-opus",
"gemini-flash",
"gemini-pro",
"gemini-ultra",
"mistral-tiny",
"mistral-small",
"mistral-medium",
"mistral-large",
"aws-claude",
"aws-claude-opus",
"gcp-claude",
"gcp-claude-opus",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-32k",
"azure-gpt4-turbo",
"azure-gpt4o",
]),
rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
rejectMessage: getEnvWithDefault(
"REJECT_MESSAGE",
@@ -540,7 +519,7 @@ function generateSigningKey() {
}
const signingKey = generateSigningKey();
export const COOKIE_SECRET = signingKey;
export const SECRET_SIGNING_KEY = signingKey;
export async function assertConfigIsValid() {
if (process.env.MODEL_RATE_LIMIT !== undefined) {
@@ -801,3 +780,7 @@ function parseCsv(val: string): string[] {
const matches = val.match(regex) || [];
return matches.map((item) => item.replace(/^"|"$/g, "").trim());
}
function getDefaultModelFamilies(): ModelFamily[] {
return MODEL_FAMILIES.filter((f) => !f.includes("dall-e")) as ModelFamily[];
}
+5 -1
@@ -29,6 +29,10 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
"mistral-large": "Mistral Large",
"aws-claude": "AWS Claude (Sonnet)",
"aws-claude-opus": "AWS Claude (Opus)",
"aws-mistral-tiny": "AWS Mistral 7B",
"aws-mistral-small": "AWS Mistral Nemo",
"aws-mistral-medium": "AWS Mistral Medium",
"aws-mistral-large": "AWS Mistral Large",
"gcp-claude": "GCP Claude (Sonnet)",
"gcp-claude-opus": "GCP Claude (Opus)",
"azure-turbo": "Azure GPT-3.5 Turbo",
@@ -41,7 +45,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
const converter = new showdown.Converter();
const customGreeting = fs.existsSync("greeting.md")
? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
? `<div id="servergreeting">${fs.readFileSync("greeting.md", "utf8")}</div>`
: "";
let infoPageHtml: string | undefined;
let infoPageLastUpdated = 0;
+9
@@ -0,0 +1,9 @@
import { NextFunction, Request, Response } from "express";
export function addV1(req: Request, res: Response, next: NextFunction) {
// Clients don't consistently use the /v1 prefix so we'll add it for them.
if (!req.path.startsWith("/v1/") && !req.path.startsWith("/v1beta/")) {
req.url = `/v1${req.url}`;
}
next();
}
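For context, the new AWS router later in this changeset mounts this middleware ahead of each vendor router; a sketch of equivalent usage, with imports assumed:

import { Router } from "express";
import { addV1 } from "./add-v1";
// awsClaude and awsMistral are the vendor routers defined elsewhere in this diff.

const router = Router();
// A request to /claude/messages is rewritten to /claude/v1/messages before the
// vendor router matches it; paths already under /v1/ or /v1beta/ pass through.
router.use("/claude", addV1, awsClaude);
router.use("/mistral", addV1, awsMistral);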
+31 -68
@@ -46,7 +46,7 @@ const getModelsResponse = () => {
"claude-3-haiku-20240307",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-5-sonnet-20240620"
"claude-3-5-sonnet-20240620",
];
const models = claudeVariants.map((id) => ({
@@ -70,7 +70,7 @@ const handleModelRequest: RequestHandler = (_req, res) => {
};
/** Only used for non-streaming requests. */
const anthropicResponseHandler: ProxyResHandlerWithBody = async (
const anthropicBlockingResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
@@ -179,6 +179,28 @@ export function transformAnthropicChatResponseToOpenAI(
};
}
/**
* If a client using the OpenAI compatibility endpoint requests an actual OpenAI
* model, reassigns it to Claude 3 Sonnet.
*/
function maybeReassignModel(req: Request) {
const model = req.body.model;
if (!model.startsWith("gpt-")) return;
req.body.model = "claude-3-sonnet-20240229";
}
/**
* If the client requests more than 4096 output tokens, the request must
* include a particular version header.
* https://docs.anthropic.com/en/release-notes/api#july-15th-2024
*/
function setAnthropicBetaHeader(req: Request) {
const { max_tokens_to_sample } = req.body;
if (max_tokens_to_sample > 4096) {
req.headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15";
}
}
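// Illustrative only (not part of this diff): a text-completion body asking for
// more than 4096 output tokens causes the beta header to be attached.
//   req.body = { model: "claude-3-5-sonnet-20240620", max_tokens_to_sample: 8192 }
//   setAnthropicBetaHeader(req)
//   => req.headers["anthropic-beta"] === "max-tokens-3-5-sonnet-2024-07-15"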
const anthropicProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://api.anthropic.com",
@@ -189,7 +211,7 @@ const anthropicProxy = createQueueMiddleware({
proxyReq: createOnProxyReqHandler({
pipeline: [addKey, addAnthropicPreamble, finalizeBody],
}),
proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
proxyRes: createOnProxyResHandler([anthropicBlockingResponseHandler]),
error: handleProxyError,
},
// Abusing pathFilter to rewrite the paths dynamically.
@@ -213,6 +235,11 @@ const anthropicProxy = createQueueMiddleware({
}),
});
const nativeAnthropicChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "anthropic" },
{ afterTransform: [setAnthropicBetaHeader] }
);
const nativeTextPreprocessor = createPreprocessorMiddleware({
inApi: "anthropic-text",
outApi: "anthropic-text",
@@ -268,11 +295,7 @@ anthropicRouter.get("/v1/models", handleModelRequest);
anthropicRouter.post(
"/v1/messages",
ipLimiter,
createPreprocessorMiddleware({
inApi: "anthropic-chat",
outApi: "anthropic-chat",
service: "anthropic",
}),
nativeAnthropicChatPreprocessor,
anthropicProxy
);
// Anthropic text completion endpoint. Translates to Anthropic chat completion
@@ -292,65 +315,5 @@ anthropicRouter.post(
preprocessOpenAICompatRequest,
anthropicProxy
);
// Temporarily force Anthropic Text to Anthropic Chat for frontends which do not
// yet support the new model. Forces claude-3. Will be removed once common
// frontends have been updated.
anthropicRouter.post(
"/v1/:type(sonnet|opus)/:action(complete|messages)",
ipLimiter,
handleAnthropicTextCompatRequest,
createPreprocessorMiddleware({
inApi: "anthropic-text",
outApi: "anthropic-chat",
service: "anthropic",
}),
anthropicProxy
);
function handleAnthropicTextCompatRequest(
req: Request,
res: Response,
next: any
) {
const type = req.params.type;
const action = req.params.action;
const alreadyInChatFormat = Boolean(req.body.messages);
const compatModel = `claude-3-${type}-20240229`;
req.log.info(
{ type, inputModel: req.body.model, compatModel, alreadyInChatFormat },
"Handling Anthropic compatibility request"
);
if (action === "messages" || alreadyInChatFormat) {
return sendErrorToClient({
req,
res,
options: {
title: "Unnecessary usage of compatibility endpoint",
message: `Your client seems to already support the new Claude API format. This endpoint is intended for clients that do not yet support the new format.\nUse the normal \`/anthropic\` proxy endpoint instead.`,
format: "unknown",
statusCode: 400,
reqId: req.id,
obj: {
requested_endpoint: "/anthropic/" + type,
correct_endpoint: "/anthropic",
},
},
});
}
req.body.model = compatModel;
next();
}
/**
* If a client using the OpenAI compatibility endpoint requests an actual OpenAI
* model, reassigns it to Claude 3 Sonnet.
*/
function maybeReassignModel(req: Request) {
const model = req.body.model;
if (!model.startsWith("gpt-")) return;
req.body.model = "claude-3-sonnet-20240229";
}
export const anthropic = anthropicRouter;
+253
@@ -0,0 +1,253 @@
import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
createPreprocessorMiddleware,
signAwsRequest,
finalizeSignedRequest,
createOnProxyReqHandler,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
} from "./middleware/response";
import {
transformAnthropicChatResponseToAnthropicText,
transformAnthropicChatResponseToOpenAI,
} from "./anthropic";
/** Only used for non-streaming requests. */
const awsResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
let newBody = body;
switch (`${req.inboundApi}<-${req.outboundApi}`) {
case "openai<-anthropic-text":
req.log.info("Transforming Anthropic Text back to OpenAI format");
newBody = transformAwsTextResponseToOpenAI(body, req);
break;
case "openai<-anthropic-chat":
req.log.info("Transforming AWS Anthropic Chat back to OpenAI format");
newBody = transformAnthropicChatResponseToOpenAI(body);
break;
case "anthropic-text<-anthropic-chat":
req.log.info("Transforming AWS Anthropic Chat back to Text format");
newBody = transformAnthropicChatResponseToAnthropicText(body);
break;
}
// AWS does not always confirm the model in the response, so we have to add it
if (!newBody.model && req.body.model) {
newBody.model = req.body.model;
}
res.status(200).json({ ...newBody, proxy: body.proxy });
};
/**
* Transforms a model response from the Anthropic API to match those from the
* OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
* is only used for non-streaming requests as streaming requests are handled
* on-the-fly.
*/
function transformAwsTextResponseToOpenAI(
awsBody: Record<string, any>,
req: Request
): Record<string, any> {
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
return {
id: "aws-" + v4(),
object: "chat.completion",
created: Date.now(),
model: req.body.model,
usage: {
prompt_tokens: req.promptTokens,
completion_tokens: req.outputTokens,
total_tokens: totalTokens,
},
choices: [
{
message: {
role: "assistant",
content: awsBody.completion?.trim(),
},
finish_reason: awsBody.stop_reason,
index: 0,
},
],
};
}
const awsClaudeProxy = createQueueMiddleware({
beforeProxy: signAwsRequest,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
if (!signedRequest) throw new Error("Must sign request before proxying");
return `${signedRequest.protocol}//${signedRequest.hostname}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([awsResponseHandler]),
error: handleProxyError,
},
}),
});
const nativeTextPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
const textToChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
/**
* Routes text completion prompts to aws anthropic-chat if they need translation
* (claude-3 based models do not support the old text completion endpoint).
*/
const preprocessAwsTextRequest: RequestHandler = (req, res, next) => {
if (req.body.model?.includes("claude-3")) {
textToChatPreprocessor(req, res, next);
} else {
nativeTextPreprocessor(req, res, next);
}
};
const oaiToAwsTextPreprocessor = createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic-text", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
const oaiToAwsChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
/**
* Routes an OpenAI prompt to either the legacy Claude text completion endpoint
* or the new Claude chat completion endpoint, based on the requested model.
*/
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
if (req.body.model?.includes("claude-3")) {
oaiToAwsChatPreprocessor(req, res, next);
} else {
oaiToAwsTextPreprocessor(req, res, next);
}
};
const awsClaudeRouter = Router();
// Native(ish) Anthropic text completion endpoint.
awsClaudeRouter.post(
"/v1/complete",
ipLimiter,
preprocessAwsTextRequest,
awsClaudeProxy
);
// Native Anthropic chat completion endpoint.
awsClaudeRouter.post(
"/v1/messages",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
),
awsClaudeProxy
);
// OpenAI-to-AWS Anthropic compatibility endpoint.
awsClaudeRouter.post(
"/v1/chat/completions",
ipLimiter,
preprocessOpenAICompatRequest,
awsClaudeProxy
);
/**
* Tries to deal with:
* - frontends sending AWS model names even when they want to use the OpenAI-
* compatible endpoint
* - frontends sending Anthropic model names that AWS doesn't recognize
* - frontends sending OpenAI model names because they expect the proxy to
* translate them
*
* If the client sends an AWS model ID, it will be used verbatim. Otherwise,
* various strategies are used to try to map a non-AWS model name to an AWS
* model ID.
*/
function maybeReassignModel(req: Request) {
const model = req.body.model;
// If it looks like an AWS model, use it as-is
if (model.includes("anthropic.claude")) {
return;
}
// Anthropic model names can look like:
// - claude-v1
// - claude-2.1
// - claude-3-5-sonnet-20240620-v1:0
const pattern =
/^(claude-)?(instant-)?(v)?(\d+)([.-](\d))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
const match = model.match(pattern);
// If there's no match, fallback to Claude v2 as it is most likely to be
// available on AWS.
if (!match) {
req.body.model = `anthropic.claude-v2:1`;
return;
}
const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match;
if (instant) {
req.body.model = "anthropic.claude-instant-v1";
return;
}
const ver = minor ? `${major}.${minor}` : major;
switch (ver) {
case "1":
case "1.0":
req.body.model = "anthropic.claude-v1";
return;
case "2":
case "2.0":
req.body.model = "anthropic.claude-v2";
return;
case "3":
case "3.0":
if (name.includes("opus")) {
req.body.model = "anthropic.claude-3-opus-20240229-v1:0";
} else if (name.includes("haiku")) {
req.body.model = "anthropic.claude-3-haiku-20240307-v1:0";
} else {
req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0";
}
return;
case "3.5":
req.body.model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
return;
}
// Fallback to Claude 2.1
req.body.model = `anthropic.claude-v2:1`;
return;
}
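// Illustrative mappings (hypothetical inputs, derived from the regex and
// switch above rather than from any upstream documentation):
//   "claude-instant-v1"          -> "anthropic.claude-instant-v1"
//   "claude-2"                   -> "anthropic.claude-v2"
//   "claude-2.1"                 -> "anthropic.claude-v2:1" (fallback)
//   "claude-3-opus-20240229"     -> "anthropic.claude-3-opus-20240229-v1:0"
//   "claude-3-5-sonnet-20240620" -> "anthropic.claude-3-5-sonnet-20240620-v1:0"
//   "gpt-4" (no regex match)     -> "anthropic.claude-v2:1" (fallback)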
export const awsClaude = awsClaudeRouter;
+110
@@ -0,0 +1,110 @@
import { Request } from "express";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { createQueueMiddleware } from "./queue";
import {
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeSignedRequest,
signAwsRequest,
} from "./middleware/request";
import { createProxyMiddleware } from "http-proxy-middleware";
import { logger } from "../logger";
import { handleProxyError } from "./middleware/common";
import { Router } from "express";
import { ipLimiter } from "./rate-limit";
import { detectMistralInputApi, transformMistralTextToMistralChat } from "./mistral-ai";
const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
let newBody = body;
if (req.inboundApi === "mistral-ai" && req.outboundApi === "mistral-text") {
newBody = transformMistralTextToMistralChat(body);
}
// AWS does not always confirm the model in the response, so we have to add it
if (!newBody.model && req.body.model) {
newBody.model = req.body.model;
}
res.status(200).json({ ...newBody, proxy: body.proxy });
};
const awsMistralProxy = createQueueMiddleware({
beforeProxy: signAwsRequest,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
if (!signedRequest) throw new Error("Must sign request before proxying");
return `${signedRequest.protocol}//${signedRequest.hostname}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([awsMistralBlockingResponseHandler]),
error: handleProxyError,
},
}),
});
function maybeReassignModel(req: Request) {
const model = req.body.model;
// If it looks like an AWS model, use it as-is
if (model.startsWith("mistral.")) {
return;
}
// Mistral 8x7B Instruct (checked before 7B because "8x7b" also contains "7b")
else if (model.includes("8x7b")) {
req.body.model = "mistral.mixtral-8x7b-instruct-v0:1";
}
// Mistral 7B Instruct
else if (model.includes("7b")) {
req.body.model = "mistral.mistral-7b-instruct-v0:2";
}
// Mistral Large (Feb 2024)
else if (model.includes("large-2402")) {
req.body.model = "mistral.mistral-large-2402-v1:0";
}
// Mistral Large 2 (July 2024)
else if (model.includes("large")) {
req.body.model = "mistral.mistral-large-2407-v1:0";
}
// Mistral Small (Feb 2024)
else if (model.includes("small")) {
req.body.model = "mistral.mistral-small-2402-v1:0";
} else {
throw new Error(
`Can't map '${model}' to a supported AWS model ID; make sure you are requesting a Mistral model supported by Amazon Bedrock`
);
}
}
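// Illustrative mappings (hypothetical inputs, derived from the branches above):
//   "mistral-7b-instruct" -> "mistral.mistral-7b-instruct-v0:2"
//   "mixtral-8x7b"        -> "mistral.mixtral-8x7b-instruct-v0:1"
//   "mistral-large-2402"  -> "mistral.mistral-large-2402-v1:0"
//   "mistral-large"       -> "mistral.mistral-large-2407-v1:0"
//   "mistral-small"       -> "mistral.mistral-small-2402-v1:0"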
const nativeMistralChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "mistral-ai", outApi: "mistral-ai", service: "aws" },
{
beforeTransform: [detectMistralInputApi],
afterTransform: [maybeReassignModel],
}
);
const awsMistralRouter = Router();
awsMistralRouter.post(
"/v1/chat/completions",
ipLimiter,
nativeMistralChatPreprocessor,
awsMistralProxy
);
export const awsMistral = awsMistralRouter;
+58 -320
@@ -1,337 +1,75 @@
import { Request, RequestHandler, Response, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
/* Shared code between AWS Claude and AWS Mistral endpoints. */
import { Request, Response, Router } from "express";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
createPreprocessorMiddleware,
signAwsRequest,
finalizeSignedRequest,
createOnProxyReqHandler,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
} from "./middleware/response";
import { transformAnthropicChatResponseToAnthropicText, transformAnthropicChatResponseToOpenAI } from "./anthropic";
import { sendErrorToClient } from "./middleware/response/error-generator";
import { addV1 } from "./add-v1";
import { awsClaude } from "./aws-claude";
import { awsMistral } from "./aws-mistral";
import { AwsBedrockKey, keyPool } from "../shared/key-management";
const LATEST_AWS_V2_MINOR_VERSION = "1";
let modelsCache: any = null;
let modelsCacheTime = 0;
const getModelsResponse = () => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
const awsRouter = Router();
awsRouter.get(["/:vendor?/v1/models", "/:vendor?/models"], handleModelsRequest);
awsRouter.use("/claude", addV1, awsClaude);
awsRouter.use("/mistral", addV1, awsMistral);
const MODELS_CACHE_TTL = 10000;
let modelsCache: Record<string, any> = {};
let modelsCacheTime: Record<string, number> = {};
function handleModelsRequest(req: Request, res: Response) {
if (!config.awsCredentials) return res.json({ object: "list", data: [] });
const vendor = req.params.vendor?.length
? req.params.vendor === "claude"
? "anthropic"
: req.params.vendor
: "all";
const cacheTime = modelsCacheTime[vendor] || 0;
if (new Date().getTime() - cacheTime < MODELS_CACHE_TTL) {
return res.json(modelsCache[vendor]);
}
const availableModelIds = new Set<string>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "aws") continue;
(key as AwsBedrockKey).modelIds.forEach((id) => availableModelIds.add(id));
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
const variants = [
const models = [
"anthropic.claude-v2",
"anthropic.claude-v2:1",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
];
const models = variants.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "anthropic",
permission: [],
root: "claude",
parent: null,
}));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
return modelsCache;
};
const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
};
/** Only used for non-streaming requests. */
const awsResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
let newBody = body;
switch (`${req.inboundApi}<-${req.outboundApi}`) {
case "openai<-anthropic-text":
req.log.info("Transforming Anthropic Text back to OpenAI format");
newBody = transformAwsTextResponseToOpenAI(body, req);
break;
case "openai<-anthropic-chat":
req.log.info("Transforming AWS Anthropic Chat back to OpenAI format");
newBody = transformAnthropicChatResponseToOpenAI(body);
break;
case "anthropic-text<-anthropic-chat":
req.log.info("Transforming AWS Anthropic Chat back to Text format");
newBody = transformAnthropicChatResponseToAnthropicText(body);
break;
}
// AWS does not always confirm the model in the response, so we have to add it
if (!newBody.model && req.body.model) {
newBody.model = req.body.model;
}
res.status(200).json({ ...newBody, proxy: body.proxy });
};
/**
* Transforms a model response from the Anthropic API to match those from the
* OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
* is only used for non-streaming requests as streaming requests are handled
* on-the-fly.
*/
function transformAwsTextResponseToOpenAI(
awsBody: Record<string, any>,
req: Request
): Record<string, any> {
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
return {
id: "aws-" + v4(),
object: "chat.completion",
created: Date.now(),
model: req.body.model,
usage: {
prompt_tokens: req.promptTokens,
completion_tokens: req.outputTokens,
total_tokens: totalTokens,
},
choices: [
{
message: {
role: "assistant",
content: awsBody.completion?.trim(),
},
finish_reason: awsBody.stop_reason,
index: 0,
},
],
};
}
const awsProxy = createQueueMiddleware({
beforeProxy: signAwsRequest,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
if (!signedRequest) throw new Error("Must sign request before proxying");
return `${signedRequest.protocol}//${signedRequest.hostname}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([awsResponseHandler]),
error: handleProxyError,
},
}),
});
const nativeTextPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
const textToChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
/**
* Routes text completion prompts to aws anthropic-chat if they need translation
* (claude-3 based models do not support the old text completion endpoint).
*/
const preprocessAwsTextRequest: RequestHandler = (req, res, next) => {
if (req.body.model?.includes("claude-3")) {
textToChatPreprocessor(req, res, next);
} else {
nativeTextPreprocessor(req, res, next);
}
};
const oaiToAwsTextPreprocessor = createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic-text", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
const oaiToAwsChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
/**
* Routes an OpenAI prompt to either the legacy Claude text completion endpoint
* or the new Claude chat completion endpoint, based on the requested model.
*/
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
if (req.body.model?.includes("claude-3")) {
oaiToAwsChatPreprocessor(req, res, next);
} else {
oaiToAwsTextPreprocessor(req, res, next);
}
};
const awsRouter = Router();
awsRouter.get("/v1/models", handleModelRequest);
// Native(ish) Anthropic text completion endpoint.
awsRouter.post("/v1/complete", ipLimiter, preprocessAwsTextRequest, awsProxy);
// Native Anthropic chat completion endpoint.
awsRouter.post(
"/v1/messages",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
),
awsProxy
);
// Temporary force-Claude3 endpoint
awsRouter.post(
"/v1/sonnet/:action(complete|messages)",
ipLimiter,
handleCompatibilityRequest,
createPreprocessorMiddleware({
inApi: "anthropic-text",
outApi: "anthropic-chat",
service: "aws",
}),
awsProxy
);
// OpenAI-to-AWS Anthropic compatibility endpoint.
awsRouter.post(
"/v1/chat/completions",
ipLimiter,
preprocessOpenAICompatRequest,
awsProxy
);
/**
* Tries to deal with:
* - frontends sending AWS model names even when they want to use the OpenAI-
* compatible endpoint
* - frontends sending Anthropic model names that AWS doesn't recognize
* - frontends sending OpenAI model names because they expect the proxy to
* translate them
*
* If client sends AWS model ID it will be used verbatim. Otherwise, various
* strategies are used to try to map a non-AWS model name to AWS model ID.
*/
function maybeReassignModel(req: Request) {
const model = req.body.model;
// If it looks like an AWS model, use it as-is
if (model.includes("anthropic.claude")) {
return;
}
// Anthropic model names can look like:
// - claude-v1
// - claude-2.1
// - claude-3-5-sonnet-20240620-v1:0
const pattern =
/^(claude-)?(instant-)?(v)?(\d+)([.-](\d{1}))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
const match = model.match(pattern);
// If there's no match, fallback to Claude v2 as it is most likely to be
// available on AWS.
if (!match) {
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
return;
}
const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match;
if (instant) {
req.body.model = "anthropic.claude-instant-v1";
return;
}
const ver = minor ? `${major}.${minor}` : major;
switch (ver) {
case "1":
case "1.0":
req.body.model = "anthropic.claude-v1";
return;
case "2":
case "2.0":
req.body.model = "anthropic.claude-v2";
return;
case "3":
case "3.0":
if (name.includes("opus")) {
req.body.model = "anthropic.claude-3-opus-20240229-v1:0";
} else if (name.includes("haiku")) {
req.body.model = "anthropic.claude-3-haiku-20240307-v1:0";
} else {
req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0";
}
return;
case "3.5":
req.body.model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
return;
}
// Fallback to Claude 2.1
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
return;
}
export function handleCompatibilityRequest(
req: Request,
res: Response,
next: any
) {
const action = req.params.action;
const alreadyInChatFormat = Boolean(req.body.messages);
const compatModel = "anthropic.claude-3-5-sonnet-20240620-v1:0";
req.log.info(
{ inputModel: req.body.model, compatModel, alreadyInChatFormat },
"Handling AWS compatibility request"
);
if (action === "messages" || alreadyInChatFormat) {
return sendErrorToClient({
req,
res,
options: {
title: "Unnecessary usage of compatibility endpoint",
message: `Your client seems to already support the new Claude API format. This endpoint is intended for clients that do not yet support the new format.\nUse the normal \`/aws/claude\` proxy endpoint instead.`,
format: "unknown",
statusCode: 400,
reqId: req.id,
obj: {
requested_endpoint: "/aws/claude/sonnet",
correct_endpoint: "/aws/claude",
},
},
"mistral.mistral-7b-instruct-v0:2",
"mistral.mixtral-8x7b-instruct-v0:1",
"mistral.mistral-large-2402-v1:0",
"mistral.mistral-large-2407-v1:0",
"mistral.mistral-small-2402-v1:0",
]
.filter((id) => availableModelIds.has(id))
.map((id) => {
const vendor = id.match(/^(.*)\./)?.[1];
return {
id,
object: "model",
created: new Date().getTime(),
owned_by: vendor,
permission: [],
root: vendor,
parent: null,
};
});
}
req.body.model = compatModel;
next();
modelsCache[vendor] = {
object: "list",
data: models.filter((m) => vendor === "all" || m.root === vendor),
};
modelsCacheTime[vendor] = new Date().getTime();
return res.json(modelsCache[vendor]);
}
export const aws = awsRouter;
+1 -4
@@ -1,6 +1,5 @@
import { Request, RequestHandler, Response, Router } from "express";
import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
@@ -17,8 +16,6 @@ import {
createOnProxyResHandler,
} from "./middleware/response";
import { transformAnthropicChatResponseToOpenAI } from "./anthropic";
import { sendErrorToClient } from "./middleware/response/error-generator";
const LATEST_GCP_SONNET_MINOR_VERSION = "20240229";
let modelsCache: any = null;
+1 -1
@@ -152,7 +152,7 @@ googleAIRouter.post(
outApi: "google-ai",
service: "google-ai",
},
{ afterTransform: [maybeReassignModel, setStreamFlag] }
{ beforeTransform: [maybeReassignModel], afterTransform: [setStreamFlag] }
),
googleAIProxy
);
+14 -9
@@ -16,6 +16,7 @@ const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
const ANTHROPIC_MESSAGES_ENDPOINT = "/v1/messages";
const ANTHROPIC_SONNET_COMPAT_ENDPOINT = "/v1/sonnet";
const ANTHROPIC_OPUS_COMPAT_ENDPOINT = "/v1/opus";
const GOOGLE_AI_COMPLETION_ENDPOINT = "/v1beta/models";
export function isTextGenerationRequest(req: Request) {
return (
@@ -27,6 +28,7 @@ export function isTextGenerationRequest(req: Request) {
ANTHROPIC_MESSAGES_ENDPOINT,
ANTHROPIC_SONNET_COMPAT_ENDPOINT,
ANTHROPIC_OPUS_COMPAT_ENDPOINT,
GOOGLE_AI_COMPLETION_ENDPOINT,
].some((endpoint) => req.path.startsWith(endpoint))
);
}
@@ -221,9 +223,12 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
switch (format) {
case "openai":
case "mistral-ai":
// Can be null if the model wants to invoke tools rather than return a
// completion.
return body.choices[0].message.content || "";
// A few possible values:
// - choices[0].message.content
// - choices[0].message with no content if model is invoking a tool
return body.choices?.[0]?.message?.content || "";
case "mistral-text":
return body.outputs?.[0]?.text || "";
case "openai-text":
return body.choices[0].text;
case "anthropic-chat":
@@ -260,22 +265,22 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
}
}
export function getModelFromBody(req: Request, body: Record<string, any>) {
export function getModelFromBody(req: Request, resBody: Record<string, any>) {
const format = req.outboundApi;
switch (format) {
case "openai":
case "openai-text":
return resBody.model;
case "mistral-ai":
return body.model;
case "mistral-text":
case "openai-image":
case "google-ai":
// These formats don't have a model in the response body.
return req.body.model;
case "anthropic-chat":
case "anthropic-text":
// Anthropic confirms the model in the response, but AWS Claude doesn't.
return body.model || req.body.model;
case "google-ai":
// Google doesn't confirm the model in the response.
return req.body.model;
return resBody.model || req.body.model;
default:
assertNever(format);
}
@@ -38,7 +38,10 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
// translation now reassigns the model earlier in the request pipeline.
case "anthropic-text":
case "anthropic-chat":
assignedKey = keyPool.get("claude-v1", service, needsMultimodal);
case "mistral-ai":
case "mistral-text":
case "google-ai":
assignedKey = keyPool.get(body.model, service);
break;
case "openai-text":
assignedKey = keyPool.get("gpt-3.5-turbo-instruct", service);
@@ -47,10 +50,8 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
assignedKey = keyPool.get("dall-e-3", service);
break;
case "openai":
case "google-ai":
case "mistral-ai":
throw new Error(
`add-key should not be called for outbound API ${outboundApi}`
`Outbound API ${outboundApi} is not supported for ${inboundApi}`
);
default:
assertNever(outboundApi);
@@ -1,14 +1,16 @@
import { HPMRequestCallback } from "../index";
import { config } from "../../../../config";
import { ForbiddenError } from "../../../../shared/errors";
import { getModelFamilyForRequest } from "../../../../shared/models";
import { HPMRequestCallback } from "../index";
/**
* Ensures the selected model family is enabled by the proxy configuration.
**/
export const checkModelFamily: HPMRequestCallback = (_proxyReq, req, res) => {
*/
export const checkModelFamily: HPMRequestCallback = (_proxyReq, req) => {
const family = getModelFamilyForRequest(req);
if (!config.allowedModelFamilies.includes(family)) {
throw new ForbiddenError(`Model family '${family}' is not enabled on this proxy`);
throw new ForbiddenError(
`Model family '${family}' is not enabled on this proxy`
);
}
};
@@ -84,9 +84,9 @@ async function executePreprocessors(
} catch (error) {
if (error.constructor.name === "ZodError") {
const msg = error?.issues
?.map((issue: ZodIssue) => issue.message)
?.map((issue: ZodIssue) => `${issue.path.join(".")}: ${issue.message}`)
.join("; ");
req.log.info(msg, "Prompt validation failed.");
req.log.warn({ issues: msg }, "Prompt validation failed.");
} else {
req.log.error(error, "Error while executing request preprocessor");
}
@@ -2,7 +2,6 @@ import { RequestPreprocessor } from "../index";
import { countTokens } from "../../../../shared/tokenization";
import { assertNever } from "../../../../shared/utils";
import {
AnthropicChatMessage,
GoogleAIChatMessage,
MistralAIChatMessage,
OpenAIChatMessage,
@@ -31,10 +30,13 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
}
case "anthropic-chat": {
req.outputTokens = req.body.max_tokens;
const prompt = {
system: req.body.system ?? "",
messages: req.body.messages,
};
let system = req.body.system ?? "";
if (Array.isArray(system)) {
system = system
.map((m: { type: string; text: string }) => m.text)
.join("\n");
}
const prompt = { system, messages: req.body.messages };
result = await countTokens({ req, prompt, service });
break;
}
@@ -50,9 +52,11 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
result = await countTokens({ req, prompt, service });
break;
}
case "mistral-ai": {
case "mistral-ai":
case "mistral-text": {
req.outputTokens = req.body.max_tokens;
const prompt: MistralAIChatMessage[] = req.body.messages;
const prompt: string | MistralAIChatMessage[] =
req.body.messages ?? req.body.prompt;
result = await countTokens({ req, prompt, service });
break;
}
@@ -56,8 +56,6 @@ function getPromptFromRequest(req: Request) {
switch (service) {
case "anthropic-chat":
return flattenAnthropicMessages(body.messages);
case "anthropic-text":
return body.prompt;
case "openai":
case "mistral-ai":
return body.messages
@@ -72,8 +70,10 @@ function getPromptFromRequest(req: Request) {
return `${msg.role}: ${text}`;
})
.join("\n\n");
case "anthropic-text":
case "openai-text":
case "openai-image":
case "mistral-text":
return body.prompt;
case "google-ai":
return body.prompt.text;
@@ -1,4 +1,4 @@
import express from "express";
import express, { Request } from "express";
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
@@ -6,8 +6,12 @@ import {
AnthropicV1TextSchema,
AnthropicV1MessagesSchema,
} from "../../../../shared/api-schemas";
import { keyPool } from "../../../../shared/key-management";
import { AwsBedrockKey, keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import {
AWSMistralV1ChatCompletionsSchema,
AWSMistralV1TextCompletionsSchema,
} from "../../../../shared/api-schemas/mistral-ai";
const AMZ_HOST =
process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
@@ -29,56 +33,33 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
req.body.prompt = preamble + req.body.prompt;
}
// AWS uses mostly the same parameters as Anthropic, with a few removed params
// and much stricter validation on unused parameters. Rather than treating it
// as a separate schema we will use the anthropic ones and strip the unused
// parameters.
// TODO: This should happen in transform-outbound-payload.ts
let strippedParams: Record<string, unknown>;
if (req.outboundApi === "anthropic-chat") {
strippedParams = AnthropicV1MessagesSchema.pick({
messages: true,
system: true,
max_tokens: true,
stop_sequences: true,
temperature: true,
top_k: true,
top_p: true,
})
.strip()
.parse(req.body);
strippedParams.anthropic_version = "bedrock-2023-05-31";
} else {
strippedParams = AnthropicV1TextSchema.pick({
prompt: true,
max_tokens_to_sample: true,
stop_sequences: true,
temperature: true,
top_k: true,
top_p: true,
})
.strip()
.parse(req.body);
}
const credential = getCredentialParts(req);
const host = AMZ_HOST.replace("%REGION%", credential.region);
// AWS only uses 2023-06-01 and does not actually check this header, but we
// set it so that the stream adapter always selects the correct transformer.
req.headers["anthropic-version"] = "2023-06-01";
// If our key has an inference profile compatible with the requested model,
// we want to use the inference profile instead of the model ID when calling
// InvokeModel as that will give us higher rate limits.
const profile =
(req.key as AwsBedrockKey).inferenceProfileIds.find((p) =>
p.includes(model)
) || model;
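// Illustrative values (not from this diff): a key with inferenceProfileIds of
// ["us.anthropic.claude-3-5-sonnet-20240620-v1:0"] and a requested model of
// "anthropic.claude-3-5-sonnet-20240620-v1:0" resolves to the profile ID,
// which is then substituted into the invoke path below.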
// Uses the AWS SDK to sign a request, then modifies our HPM proxy request
// with the headers generated by the SDK.
const newRequest = new HttpRequest({
method: "POST",
protocol: "https:",
hostname: host,
path: `/model/${model}/invoke${stream ? "-with-response-stream" : ""}`,
path: `/model/${profile}/invoke${stream ? "-with-response-stream" : ""}`,
headers: {
["Host"]: host,
["content-type"]: "application/json",
},
body: JSON.stringify(strippedParams),
body: JSON.stringify(applyAwsStrictValidation(req)),
});
if (stream) {
@@ -89,7 +70,13 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
const { key, body, inboundApi, outboundApi } = req;
req.log.info(
{ key: key.hash, model: body.model, inboundApi, outboundApi },
{
key: key.hash,
model: body.model,
inferenceProfile: profile,
inboundApi,
outboundApi,
},
"Assigned AWS credentials to request"
);
@@ -128,3 +115,50 @@ async function sign(request: HttpRequest, credential: Credential) {
return signer.sign(request);
}
function applyAwsStrictValidation(req: Request): unknown {
// AWS uses vendor API formats but imposes additional (more strict) validation
// rules, namely that extraneous parameters are not allowed. We will validate
// using the vendor's zod schema but apply `.strip` to ensure that any
// extraneous parameters are removed.
let strippedParams: Record<string, unknown> = {};
switch (req.outboundApi) {
case "anthropic-text":
strippedParams = AnthropicV1TextSchema.pick({
prompt: true,
max_tokens_to_sample: true,
stop_sequences: true,
temperature: true,
top_k: true,
top_p: true,
})
.strip()
.parse(req.body);
break;
case "anthropic-chat":
strippedParams = AnthropicV1MessagesSchema.pick({
messages: true,
system: true,
max_tokens: true,
stop_sequences: true,
temperature: true,
top_k: true,
top_p: true,
tools: true,
tool_choice: true,
})
.strip()
.parse(req.body);
strippedParams.anthropic_version = "bedrock-2023-05-31";
break;
case "mistral-ai":
strippedParams = AWSMistralV1ChatCompletionsSchema.parse(req.body);
break;
case "mistral-text":
strippedParams = AWSMistralV1TextCompletionsSchema.parse(req.body);
break;
default:
throw new Error("Unexpected outbound API for AWS.");
}
return strippedParams;
}
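The .strip() behavior relied on above can be seen in isolation; a minimal sketch with hypothetical fields, assuming zod is available:

import { z } from "zod";

const Schema = z
  .object({ max_tokens: z.number(), user: z.string().optional() })
  .pick({ max_tokens: true })
  .strip();

// Extraneous keys are silently removed rather than rejected:
Schema.parse({ max_tokens: 10, user: "abc", logprobs: true });
// => { max_tokens: 10 }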
@@ -24,7 +24,6 @@ export const signGcpRequest: RequestPreprocessor = async (req) => {
req.isStreaming = String(stream) === "true";
// TODO: This should happen in transform-outbound-payload.ts
// TODO: Support tools
let strippedParams: Record<string, unknown>;
strippedParams = AnthropicV1MessagesSchema.pick({
messages: true,
@@ -34,6 +33,8 @@ export const signGcpRequest: RequestPreprocessor = async (req) => {
temperature: true,
top_k: true,
top_p: true,
tools: true,
tool_choice: true,
stream: true,
})
.strip()
@@ -1,3 +1,4 @@
import { Request } from "express";
import {
API_REQUEST_VALIDATORS,
API_REQUEST_TRANSFORMERS,
@@ -12,29 +13,33 @@ import { RequestPreprocessor } from "../index";
/** Transforms an incoming request body to one that matches the target API. */
export const transformOutboundPayload: RequestPreprocessor = async (req) => {
const sameService = req.inboundApi === req.outboundApi;
const alreadyTransformed = req.retryCount > 0;
const notTransformable =
!isTextGenerationRequest(req) && !isImageGenerationRequest(req);
if (alreadyTransformed || notTransformable) return;
// TODO: this should be an APIFormatTransformer
if (req.inboundApi === "mistral-ai") {
const messages = req.body.messages;
req.body.messages = fixMistralPrompt(messages);
req.log.info(
{ old: messages.length, new: req.body.messages.length },
"Fixed Mistral prompt"
if (alreadyTransformed) {
return;
} else if (notTransformable) {
// This is probably an indication of a bug in the proxy.
const { inboundApi, outboundApi, method, path } = req;
req.log.warn(
{ inboundApi, outboundApi, method, path },
"`transformOutboundPayload` called on a non-transformable request."
);
return;
}
if (sameService) {
applyMistralPromptFixes(req);
// Native prompts are those which were already provided by the client in the
// target API format. We don't need to transform them.
const isNativePrompt = req.inboundApi === req.outboundApi;
if (isNativePrompt) {
const result = API_REQUEST_VALIDATORS[req.inboundApi].safeParse(req.body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body: req.body },
"Request validation failed"
"Native prompt request validation failed."
);
throw result.error;
}
@@ -42,11 +47,12 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
return;
}
// Prompt requires translation from one API format to another.
const transformation = `${req.inboundApi}->${req.outboundApi}` as const;
const transFn = API_REQUEST_TRANSFORMERS[transformation];
if (transFn) {
req.log.info({ transformation }, "Transforming request");
req.log.info({ transformation }, "Transforming request...");
req.body = await transFn(req);
return;
}
@@ -55,3 +61,36 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
`${transformation} proxying is not supported. Make sure your client is configured to send requests in the correct format and to the correct endpoint.`
);
};
// handles weird cases that don't fit into our abstractions
function applyMistralPromptFixes(req: Request): void {
if (req.inboundApi === "mistral-ai") {
// Mistral Chat is very similar to OpenAI's format, but not identical, and
// many clients don't properly handle the differences. We validate the
// Mistral prompt and attempt to fix it if validation fails. It will be
// re-validated after this function returns.
const result = API_REQUEST_VALIDATORS["mistral-ai"].parse(req.body);
req.body.messages = fixMistralPrompt(result.messages);
req.log.info(
{ n: req.body.messages.length, prev: result.messages.length },
"Applied Mistral chat prompt fixes."
);
// If the prompt relies on `prefix: true` for the last message, we need to
// convert it to a text completions request because AWS Mistral support for
// this feature is broken.
// On Mistral La Plateforme, we can't do this because they don't expose
// a text completions endpoint.
const { messages } = req.body;
const lastMessage = messages && messages[messages.length - 1];
if (lastMessage?.role === "assistant" && req.service === "aws") {
// Enable prefix if the client forgot to, otherwise the template will insert
// an EOS token, which is very unlikely to be what the client wants.
lastMessage.prefix = true;
req.outboundApi = "mistral-text";
req.log.info(
"Native Mistral chat prompt relies on assistant message prefix. Converting to text completions request."
);
}
}
}
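// Worked example of the prefix handling above (assumed inputs, not from this
// diff): a chat request ending with
//   { role: "assistant", content: "Roses are", prefix: true }
// is rerouted to the text completions API, where the rendered template ends
// mid-message so the model continues "Roses are" rather than starting a new
// turn.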
@@ -38,6 +38,7 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
proxyMax = GOOGLE_AI_MAX_CONTEXT;
break;
case "mistral-ai":
case "mistral-text":
proxyMax = MISTRAL_AI_MAX_CONTENT;
break;
case "openai-image":
@@ -57,6 +58,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
modelMax = 16384;
} else if (model.match(/^gpt-4o/)) {
modelMax = 128000;
} else if (model.match(/^chatgpt-4o/)) {
modelMax = 128000;
} else if (model.match(/gpt-4-turbo(-\d{4}-\d{2}-\d{2})?$/)) {
modelMax = 131072;
} else if (model.match(/gpt-4-turbo(-preview)?$/)) {
@@ -28,6 +28,7 @@ export const validateVision: RequestPreprocessor = async (req) => {
case "anthropic-text":
case "google-ai":
case "mistral-ai":
case "mistral-text":
case "openai-image":
case "openai-text":
return;
@@ -189,6 +189,11 @@ export function buildSpoofedCompletion({
},
],
};
case "mistral-text":
return {
outputs: [{ text: content, stop_reason: title }],
model,
};
case "openai-text":
return {
id: "error-" + id,
@@ -267,6 +272,11 @@ export function buildSpoofedSSE({
choices: [{ delta: { content }, index: 0, finish_reason: title }],
};
break;
case "mistral-text":
event = {
outputs: [{ text: content, stop_reason: title }],
};
break;
case "openai-text":
event = {
id: "cmpl-" + id,
@@ -22,18 +22,19 @@ import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
const pipelineAsync = promisify(pipeline);
/**
* `handleStreamedResponse` consumes and transforms a streamed response from the
* upstream service, forwarding events to the client in their requested format.
* `handleStreamedResponse` consumes a streamed response from the upstream API,
* decodes chunk-by-chunk into a stream of events, transforms those events into
* the client's requested format, and forwards the result to the client.
*
* After the entire stream has been consumed, it resolves with the full response
* body so that subsequent middleware in the chain can process it as if it were
* a non-streaming response.
* a non-streaming response (to count output tokens, track usage, etc).
*
* In the event of an error, the request's streaming flag is unset and the non-
* streaming response handler is called instead.
*
* If the error is retryable, that handler will re-enqueue the request and also
* reset the streaming flag. Unfortunately the streaming flag is set and unset
* in multiple places, so it's hard to keep track of.
* In the event of an error, the request's streaming flag is unset and the
* request is bounced back to the non-streaming response handler. If the error
* is retryable, that handler will re-enqueue the request and also reset the
* streaming flag. Unfortunately the streaming flag is set and unset in multiple
* places, so it's hard to keep track of.
*/
export const handleStreamedResponse: RawResponseBodyHandler = async (
proxyRes,
@@ -70,13 +71,21 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
logger: req.log,
};
// Decoder turns the raw response stream into a stream of events in some
// format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
// While the request is streaming, aggregator collects all events so that we
// can compile them into a single response object and publish that to the
// remaining middleware. Because we have an OpenAI transformer for every
// supported format, EventAggregator always consumes OpenAI events so that we
// only have to write one aggregator (OpenAI input) for each output format.
const aggregator = new EventAggregator(req);
// Decoder reads from the raw response buffer and produces a stream of
// discrete events in some format (text/event-stream, vnd.amazon.event-stream,
// streaming JSON, etc).
const decoder = getDecoder({ ...streamOptions, input: proxyRes });
// Adapter transforms the decoded events into server-sent events.
// Adapter consumes the decoded events and produces server-sent events so we
// have a standard event format for the client and to translate between API
// message formats.
const adapter = new SSEStreamAdapter(streamOptions);
// Aggregator compiles all events into a single response object.
const aggregator = new EventAggregator({ format: req.outboundApi });
// Transformer converts server-sent events from one vendor's API message
// format to another.
const transformer = new SSEMessageTransformer({
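// Taken together (assumed wiring; this hunk is truncated before the actual
// pipeline call), bytes from proxyRes flow decoder -> adapter -> transformer,
// while the aggregator buffers each transformed event:
//   await pipelineAsync(proxyRes, decoder, adapter, transformer);
//   transformer.on("data", (event) => aggregator.addEvent(event));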
+13 -12
View File
@@ -11,7 +11,8 @@ import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
import {
AnthropicChatMessage,
flattenAnthropicMessages, GoogleAIChatMessage,
flattenAnthropicMessages,
GoogleAIChatMessage,
MistralAIChatMessage,
OpenAIChatMessage,
} from "../../../shared/api-schemas";
@@ -74,8 +75,16 @@ const getPromptForRequest = (
case "mistral-ai":
return req.body.messages;
case "anthropic-chat":
return { system: req.body.system, messages: req.body.messages };
let system = req.body.system;
if (Array.isArray(system)) {
system = system
.map((m: { type: string; text: string }) => m.text)
.join("\n");
}
return { system, messages: req.body.messages };
case "openai-text":
case "anthropic-text":
case "mistral-text":
return req.body.prompt;
case "openai-image":
return {
@@ -85,8 +94,6 @@ const getPromptForRequest = (
quality: req.body.quality,
revisedPrompt: responseBody.data[0].revised_prompt,
};
case "anthropic-text":
return req.body.prompt;
case "google-ai":
return { contents: req.body.contents };
default:
@@ -113,9 +120,7 @@ const flattenMessages = (
if (isGoogleAIChatPrompt(val)) {
return val.contents
.map(({ parts, role }) => {
const text = parts
.map((p) => p.text)
.join("\n");
const text = parts.map((p) => p.text).join("\n");
return `${role}: ${text}`;
})
.join("\n");
@@ -143,11 +148,7 @@ const flattenMessages = (
function isGoogleAIChatPrompt(
val: unknown
): val is { contents: GoogleAIChatMessage[] } {
return (
typeof val === "object" &&
val !== null &&
"contents" in val
);
return typeof val === "object" && val !== null && "contents" in val;
}
function isAnthropicChatPrompt(
@@ -0,0 +1,39 @@
import { OpenAIChatCompletionStreamEvent } from "../index";
export type MistralChatCompletionResponse = {
choices: {
index: number;
message: { role: string; content: string };
finish_reason: string | null;
}[];
};
/**
* Given a list of OpenAI chat completion events, compiles them into a single
* finalized Mistral chat completion response so that non-streaming middleware
* can operate on it as if it were a blocking response.
*/
export function mergeEventsForMistralChat(
events: OpenAIChatCompletionStreamEvent[]
): MistralChatCompletionResponse {
let merged: MistralChatCompletionResponse = {
choices: [
{ index: 0, message: { role: "", content: "" }, finish_reason: "" },
],
};
merged = events.reduce((acc, event, i) => {
// The first event will only contain role assignment and response metadata
if (i === 0) {
acc.choices[0].message.role = event.choices[0].delta.role ?? "assistant";
return acc;
}
acc.choices[0].finish_reason = event.choices[0].finish_reason ?? "";
if (event.choices[0].delta.content) {
acc.choices[0].message.content += event.choices[0].delta.content;
}
return acc;
}, merged);
return merged;
}
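// Worked example of the reduction above; events are trimmed to the fields
// this function reads (real OpenAI chunks carry id/model/etc. as well):
const demoChatEvents = [
  { choices: [{ index: 0, delta: { role: "assistant" }, finish_reason: null }] },
  { choices: [{ index: 0, delta: { content: "Hi" }, finish_reason: null }] },
  { choices: [{ index: 0, delta: {}, finish_reason: "stop" }] },
] as unknown as OpenAIChatCompletionStreamEvent[];
// mergeEventsForMistralChat(demoChatEvents) returns:
//   { choices: [{ index: 0, message: { role: "assistant", content: "Hi" },
//     finish_reason: "stop" }] }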
@@ -0,0 +1,33 @@
import { OpenAIChatCompletionStreamEvent } from "../index";
export type MistralTextCompletionResponse = {
outputs: {
text: string;
stop_reason: string | null;
}[];
};
/**
* Given a list of OpenAI chat completion events, compiles them into a single
* finalized Mistral text completion response so that non-streaming middleware
* can operate on it as if it were a blocking response.
*/
export function mergeEventsForMistralText(
events: OpenAIChatCompletionStreamEvent[]
): MistralTextCompletionResponse {
let merged: MistralTextCompletionResponse = {
outputs: [{ text: "", stop_reason: "" }],
};
merged = events.reduce((acc, event, i) => {
// The first event will only contain role assignment and response metadata
if (i === 0) {
return acc;
}
acc.outputs[0].text += event.choices[0].delta.content ?? "";
acc.outputs[0].stop_reason = event.choices[0].finish_reason ?? "";
return acc;
}, merged);
return merged;
}
@@ -24,7 +24,7 @@ export function getAwsEventStreamDecoder(params: {
if (eventType === "chunk") {
result = input[eventType];
} else {
// AWS unmarshaller treats non-chunk (errors and exceptions) oddly.
// AWS unmarshaller treats non-chunk events (errors and exceptions) oddly.
result = { [eventType]: input[eventType] } as any;
}
return result;
@@ -1,3 +1,4 @@
import express from "express";
import { APIFormat } from "../../../../shared/key-management";
import { assertNever } from "../../../../shared/utils";
import {
@@ -6,8 +7,13 @@ import {
mergeEventsForAnthropicText,
mergeEventsForOpenAIChat,
mergeEventsForOpenAIText,
mergeEventsForMistralChat,
mergeEventsForMistralText,
AnthropicV2StreamEvent,
OpenAIChatCompletionStreamEvent,
mistralAIToOpenAI,
MistralAIStreamEvent,
MistralChatCompletionEvent,
} from "./index";
/**
@@ -15,45 +21,70 @@ import {
* compiles them into a single finalized response for downstream middleware.
*/
export class EventAggregator {
private readonly format: APIFormat;
private readonly model: string;
private readonly requestFormat: APIFormat;
private readonly responseFormat: APIFormat;
private readonly events: OpenAIChatCompletionStreamEvent[];
constructor({ format }: { format: APIFormat }) {
constructor({ body, inboundApi, outboundApi }: express.Request) {
this.events = [];
this.format = format;
this.requestFormat = inboundApi;
this.responseFormat = outboundApi;
this.model = body.model;
}
addEvent(event: OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent) {
addEvent(
event:
| OpenAIChatCompletionStreamEvent
| AnthropicV2StreamEvent
| MistralAIStreamEvent
) {
if (eventIsOpenAIEvent(event)) {
this.events.push(event);
} else {
// horrible special case. previously all transformers' target format was
// openai, so the event aggregator could conveniently assume all incoming
// events were in openai format.
// now we have added anthropic-chat-to-text, so aggregator needs to know
// how to collapse events from two formats.
// because that is annoying, we will simply transform anthropic events to
// openai (even if the client didn't ask for openai) so we don't have to
// write aggregation logic for anthropic chat (which is also a troublesome
// stateful format).
const openAIEvent = anthropicV2ToOpenAI({
data: `event: completion\ndata: ${JSON.stringify(event)}\n\n`,
lastPosition: -1,
index: 0,
fallbackId: event.log_id || "event-aggregator-fallback",
fallbackModel: event.model || "claude-3-fallback",
});
if (openAIEvent.event) {
this.events.push(openAIEvent.event);
// now we have added some transformers that convert between non-openai
// formats, so aggregator needs to know how to collapse for more than
// just openai.
// because writing aggregation logic for every possible output format is
// annoying, we will just transform any non-openai output events to openai
// format (even if the client did not request openai at all) so that we
// still only need to write aggregators for openai SSEs.
let openAIEvent: OpenAIChatCompletionStreamEvent | undefined;
switch (this.requestFormat) {
case "anthropic-text":
assertIsAnthropicV2Event(event);
openAIEvent = anthropicV2ToOpenAI({
data: `event: completion\ndata: ${JSON.stringify(event)}\n\n`,
lastPosition: -1,
index: 0,
fallbackId: event.log_id || "fallback-" + Date.now(),
fallbackModel: event.model || this.model || "fallback-claude-3",
})?.event;
break;
case "mistral-ai":
assertIsMistralChatEvent(event);
openAIEvent = mistralAIToOpenAI({
data: `data: ${JSON.stringify(event)}\n\n`,
lastPosition: -1,
index: 0,
fallbackId: "fallback-" + Date.now(),
fallbackModel: this.model || "fallback-mistral",
})?.event;
break;
}
if (openAIEvent) {
this.events.push(openAIEvent);
}
}
}
getFinalResponse() {
switch (this.format) {
switch (this.responseFormat) {
case "openai":
case "google-ai":
case "mistral-ai":
case "google-ai": // TODO: this is probably wrong now that we support native Google Makersuite prompts
return mergeEventsForOpenAIChat(this.events);
case "openai-text":
return mergeEventsForOpenAIText(this.events);
@@ -61,10 +92,16 @@ export class EventAggregator {
return mergeEventsForAnthropicText(this.events);
case "anthropic-chat":
return mergeEventsForAnthropicChat(this.events);
case "mistral-ai":
return mergeEventsForMistralChat(this.events);
case "mistral-text":
return mergeEventsForMistralText(this.events);
case "openai-image":
throw new Error(`SSE aggregation not supported for ${this.format}`);
throw new Error(
`SSE aggregation not supported for ${this.responseFormat}`
);
default:
assertNever(this.format);
assertNever(this.responseFormat);
}
}
@@ -78,3 +115,17 @@ function eventIsOpenAIEvent(
): event is OpenAIChatCompletionStreamEvent {
return event?.object === "chat.completion.chunk";
}
function assertIsAnthropicV2Event(
event: any
): asserts event is AnthropicV2StreamEvent {
if (!event?.completion) {
throw new Error(`Bad event for Anthropic V2 SSE aggregation`);
}
}
function assertIsMistralChatEvent(
event: any
): asserts event is MistralChatCompletionEvent {
if (!event?.choices) {
throw new Error(`Bad event for Mistral SSE aggregation`);
}
}
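// For example (assumed shapes), a Mistral chat event passed to addEvent when
// the client requested mistral-ai:
//   { choices: [{ index: 0, message: { role: "assistant", content: "Hi" },
//     stop_reason: null }] }
// is first converted by mistralAIToOpenAI into an OpenAI chunk whose
// delta.content is "Hi", and only that normalized chunk is buffered.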
@@ -7,6 +7,25 @@ export type SSEResponseTransformArgs<S = Record<string, any>> = {
state?: S;
};
export type MistralChatCompletionEvent = {
choices: {
index: number;
message: { role: string; content: string };
stop_reason: string | null;
}[];
};
export type MistralTextCompletionEvent = {
outputs: { text: string; stop_reason: string | null }[];
};
export type MistralAIStreamEvent = {
"amazon-bedrock-invocationMetrics"?: {
inputTokenCount: number;
outputTokenCount: number;
invocationLatency: number;
firstByteLatency: number;
};
} & (MistralChatCompletionEvent | MistralTextCompletionEvent);
export type AnthropicV2StreamEvent = {
log_id?: string;
model?: string;
@@ -41,8 +60,12 @@ export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
export { anthropicChatToAnthropicV2 } from "./transformers/anthropic-chat-to-anthropic-v2";
export { anthropicChatToOpenAI } from "./transformers/anthropic-chat-to-openai";
export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
export { mistralAIToOpenAI } from "./transformers/mistral-ai-to-openai";
export { mistralTextToMistralChat } from "./transformers/mistral-text-to-mistral-chat";
export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
export { mergeEventsForAnthropicText } from "./aggregators/anthropic-text";
export { mergeEventsForAnthropicChat } from "./aggregators/anthropic-chat";
export { mergeEventsForMistralChat } from "./aggregators/mistral-chat";
export { mergeEventsForMistralText } from "./aggregators/mistral-text";
@@ -11,8 +11,11 @@ import {
googleAIToOpenAI,
OpenAIChatCompletionStreamEvent,
openAITextToOpenAIChat,
mistralAIToOpenAI,
mistralTextToMistralChat,
passthroughToOpenAI,
StreamingCompletionTransformer,
MistralChatCompletionEvent,
} from "./index";
type SSEMessageTransformerOptions = TransformOptions & {
@@ -35,7 +38,9 @@ export class SSEMessageTransformer extends Transform {
private readonly inputFormat: APIFormat;
private readonly transformFn: StreamingCompletionTransformer<
// TODO: Refactor transformers to not assume only OpenAI events as output
OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent
| OpenAIChatCompletionStreamEvent
| AnthropicV2StreamEvent
| MistralChatCompletionEvent
>;
private readonly log;
private readonly fallbackId: string;
@@ -121,16 +126,17 @@ function eventIsOpenAIEvent(
function getTransformer(
responseApi: APIFormat,
version?: string,
// There's only one case where we're not transforming back to OpenAI, which is
// Anthropic Chat response -> Anthropic Text request. This parameter is only
// used for that case.
// In most cases, we are transforming back to OpenAI. Some responses can be
// translated between two non-OpenAI formats, eg Anthropic Chat -> Anthropic
// Text, or Mistral Text -> Mistral Chat.
requestApi: APIFormat = "openai"
): StreamingCompletionTransformer<
OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent
| OpenAIChatCompletionStreamEvent
| AnthropicV2StreamEvent
| MistralChatCompletionEvent
> {
switch (responseApi) {
case "openai":
case "mistral-ai":
return passthroughToOpenAI;
case "openai-text":
return openAITextToOpenAIChat;
@@ -140,10 +146,16 @@ function getTransformer(
: anthropicV2ToOpenAI;
case "anthropic-chat":
return requestApi === "anthropic-text"
? anthropicChatToAnthropicV2
? anthropicChatToAnthropicV2 // User's legacy text prompt was converted to chat, and response must be converted back to text
: anthropicChatToOpenAI;
case "google-ai":
return googleAIToOpenAI;
case "mistral-ai":
return mistralAIToOpenAI;
case "mistral-text":
return requestApi === "mistral-ai"
? mistralTextToMistralChat // User's chat request was converted to text, and response must be converted back to chat
: mistralAIToOpenAI;
case "openai-image":
throw new Error(`SSE transformation not supported for ${responseApi}`);
default:
@@ -55,8 +55,10 @@ export class SSEStreamAdapter extends Transform {
if ("completion" in eventObj) {
return ["event: completion", `data: ${event}`].join(`\n`);
} else {
} else if (eventObj.type) {
return [`event: ${eventObj.type}`, `data: ${event}`].join(`\n`);
} else {
return `data: ${event}`;
}
}
// noinspection FallThroughInSwitchStatementJS -- non-JSON data is unexpected
@@ -0,0 +1,76 @@
import { logger } from "../../../../../logger";
import { MistralAIStreamEvent, SSEResponseTransformArgs } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
const log = logger.child({
module: "sse-transformer",
transformer: "mistral-ai-to-openai",
});
export const mistralAIToOpenAI = (params: SSEResponseTransformArgs) => {
const { data } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data || rawEvent.data === "[DONE]") {
return { position: -1 };
}
const completionEvent = asCompletion(rawEvent);
if (!completionEvent) {
return { position: -1 };
}
if ("choices" in completionEvent) {
const newChatEvent = {
id: params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
model: params.fallbackModel,
choices: [
{
index: completionEvent.choices[0].index,
delta: { content: completionEvent.choices[0].message.content },
finish_reason: completionEvent.choices[0].stop_reason,
},
],
};
return { position: -1, event: newChatEvent };
} else if ("outputs" in completionEvent) {
const newTextEvent = {
id: params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
model: params.fallbackModel,
choices: [
{
index: 0,
delta: { content: completionEvent.outputs[0].text },
finish_reason: completionEvent.outputs[0].stop_reason,
},
],
};
return { position: -1, event: newTextEvent };
}
// should never happen
return { position: -1 };
};
function asCompletion(event: ServerSentEvent): MistralAIStreamEvent | null {
try {
const parsed = JSON.parse(event.data);
if (
(Array.isArray(parsed.choices) &&
parsed.choices[0].message !== undefined) ||
(Array.isArray(parsed.outputs) && parsed.outputs[0].text !== undefined)
) {
return parsed;
} else {
// noinspection ExceptionCaughtLocallyJS
throw new Error("Missing required fields");
}
} catch (error: any) {
log.warn({ error: error.stack, event }, "Received invalid data event");
}
return null;
}
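// Worked example (assumed wire format): an AWS Mistral text event
//   data: {"outputs":[{"text":"Hi","stop_reason":null}]}
// yields an OpenAI chunk roughly like
//   { id: <fallbackId>, object: "chat.completion.chunk", created: <now>,
//     model: <fallbackModel>,
//     choices: [{ index: 0, delta: { content: "Hi" }, finish_reason: null }] }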
@@ -0,0 +1,63 @@
import {
MistralChatCompletionEvent,
MistralTextCompletionEvent,
StreamingCompletionTransformer,
} from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";
const log = logger.child({
module: "sse-transformer",
transformer: "mistral-text-to-mistral-chat",
});
/**
* Transforms an incoming Mistral Text SSE to an equivalent Mistral Chat SSE.
* This is generally used when a client sends a Mistral Chat prompt, but we
* convert it to Mistral Text before sending it to the API to work around
* some bugs in Mistral/AWS prompt templating. In these cases we need to convert
* the response back to Mistral Chat.
*/
export const mistralTextToMistralChat: StreamingCompletionTransformer<
MistralChatCompletionEvent
> = (params) => {
const { data } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data) {
return { position: -1 };
}
const textCompletion = asTextCompletion(rawEvent);
if (!textCompletion) {
return { position: -1 };
}
const chatEvent: MistralChatCompletionEvent = {
choices: [
{
index: 0,
message: { role: "assistant", content: textCompletion.outputs[0].text },
stop_reason: textCompletion.outputs[0].stop_reason,
},
],
};
return { position: -1, event: chatEvent };
};
function asTextCompletion(
event: ServerSentEvent
): MistralTextCompletionEvent | null {
try {
const parsed = JSON.parse(event.data);
if (Array.isArray(parsed.outputs) && parsed.outputs[0].text !== undefined) {
return parsed;
} else {
// noinspection ExceptionCaughtLocallyJS
throw new Error("Missing required fields");
}
} catch (error: any) {
log.warn({ error: error.stack, event }, "Received invalid data event");
}
return null;
}
+49 -8
View File
@@ -1,4 +1,4 @@
import { RequestHandler, Router } from "express";
import express, { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
@@ -21,6 +21,7 @@ import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { BadRequestError } from "../shared/errors";
// Mistral can't settle on a single naming scheme and deprecates models within
// months of releasing them, so this list is hard to keep up to date. Last
// updated 2024-07-28.
@@ -61,7 +62,7 @@ export const KNOWN_MISTRAL_AI_MODELS = [
"mistral-medium-latest",
"mistral-medium-2312",
"mistral-tiny",
"mistral-tiny-2312"
"mistral-tiny-2312",
];
let modelsCache: any = null;
@@ -108,9 +109,24 @@ const mistralAIResponseHandler: ProxyResHandlerWithBody = async (
throw new Error("Expected body to be an object");
}
res.status(200).json({ ...body, proxy: body.proxy });
let newBody = body;
if (req.inboundApi === "mistral-text" && req.outboundApi === "mistral-ai") {
newBody = transformMistralTextToMistralChat(body);
}
res.status(200).json({ ...newBody, proxy: body.proxy });
};
export function transformMistralTextToMistralChat(textBody: any) {
return {
...textBody,
choices: [
{ message: { content: textBody.outputs[0].text, role: "assistant" } },
],
outputs: undefined,
};
}
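// For example, a blocking AWS text completion body (assumed shape)
//   { outputs: [{ text: "Hello!", stop_reason: "stop" }] }
// is returned to the chat client as
//   { choices: [{ message: { content: "Hello!", role: "assistant" } }],
//     outputs: undefined }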
const mistralAIProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://api.mistral.ai",
@@ -133,12 +149,37 @@ mistralAIRouter.get("/v1/models", handleModelRequest);
mistralAIRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware({
inApi: "mistral-ai",
outApi: "mistral-ai",
service: "mistral-ai",
}),
createPreprocessorMiddleware(
{
inApi: "mistral-ai",
outApi: "mistral-ai",
service: "mistral-ai",
},
{ beforeTransform: [detectMistralInputApi] }
),
mistralAIProxy
);
/**
* We can't determine if a request is Mistral text or chat just from the path
* because they both use the same endpoint. We need to check the request body
* for either `messages` or `prompt`.
* @param req
*/
export function detectMistralInputApi(req: Request) {
const { messages, prompt } = req.body;
if (messages) {
req.inboundApi = "mistral-ai";
req.outboundApi = "mistral-ai";
} else if (prompt && req.service === "mistral-ai") {
// Mistral La Plateforme doesn't expose a text completions endpoint.
throw new BadRequestError(
"Mistral (via La Plateforme API) does not support text completions. This format is only supported on Mistral via the AWS API."
);
} else if (prompt && req.service === "aws") {
req.inboundApi = "mistral-text";
req.outboundApi = "mistral-text";
}
}
export const mistralAI = mistralAIRouter;
+2
View File
@@ -35,6 +35,8 @@ export const KNOWN_OPENAI_MODELS = [
// GPT4o Mini
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
// GPT4o (ChatGPT)
"chatgpt-4o-latest",
// GPT4 Turbo (superseded by GPT4o)
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09", // gpt4-turbo stable, with vision
+12 -48
View File
@@ -22,7 +22,7 @@ import {
} from "../shared/models";
import { initializeSseStream } from "../shared/streaming";
import { logger } from "../logger";
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
import { getUniqueIps } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
import { handleProxyError } from "./middleware/common";
import { sendErrorToClient } from "./middleware/response/error-generator";
@@ -31,7 +31,9 @@ const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
/** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = parseInt(process.env.USER_CONCURRENCY_LIMIT ?? "1");
const USER_CONCURRENCY_LIMIT = parseInt(
process.env.USER_CONCURRENCY_LIMIT ?? "1"
);
/** Maximum number of queue slots for Agnai.chat requests. */
const AGNAI_CONCURRENCY_LIMIT = USER_CONCURRENCY_LIMIT * 5;
const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
@@ -58,39 +60,20 @@ const QUEUE_JOIN_TIMEOUT = 5000;
function getIdentifier(req: Request) {
if (req.user) return req.user.token;
if (req.risuToken) return req.risuToken;
if (isFromSharedIp(req)) return "shared-ip";
// if (isFromSharedIp(req)) return "shared-ip";
return req.ip;
}
const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
getIdentifier(queued) === getIdentifier(incoming);
const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);
async function enqueue(req: Request) {
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
let isGuest = req.user?.token === undefined;
// Requests from shared IP addresses such as Agnai.chat are exempt from IP-
// based rate limiting but can only occupy a certain number of slots in the
// queue. Authenticated users always get a single spot in the queue.
const isSharedIp = isFromSharedIp(req);
const maxConcurrentQueuedRequests =
isGuest && isSharedIp ? AGNAI_CONCURRENCY_LIMIT : USER_CONCURRENCY_LIMIT;
if (enqueuedRequestCount >= maxConcurrentQueuedRequests) {
if (isSharedIp) {
// Re-enqueued requests are not counted towards the limit since they
// already made it through the queue once.
if (req.retryCount === 0) {
throw new TooManyRequestsError(
"Too many agnai.chat requests are already queued"
);
}
} else {
throw new TooManyRequestsError(
"Your IP or user token already has another request in the queue."
);
}
if (enqueuedRequestCount >= USER_CONCURRENCY_LIMIT) {
throw new TooManyRequestsError(
"Your IP or user token already has another request in the queue."
);
}
// shitty hack to remove hpm's event listeners on retried requests
@@ -146,19 +129,7 @@ export async function reenqueueRequest(req: Request) {
}
function getQueueForPartition(partition: ModelFamily): Request[] {
return queue
.filter((req) => getModelFamilyForRequest(req) === partition)
.sort((a, b) => {
// Certain requests are exempted from IP-based rate limiting because they
// come from a shared IP address. To prevent these requests from starving
// out other requests during periods of high traffic, we sort them to the
// end of the queue.
const aIsExempted = isFromSharedIp(a);
const bIsExempted = isFromSharedIp(b);
if (aIsExempted && !bIsExempted) return 1;
if (!aIsExempted && bIsExempted) return -1;
return 0;
});
return queue.filter((req) => getModelFamilyForRequest(req) === partition);
}
export function dequeue(partition: ModelFamily): Request | undefined {
@@ -261,7 +232,6 @@ let waitTimes: {
partition: ModelFamily;
start: number;
end: number;
isDeprioritized: boolean;
}[] = [];
/** Adds a successful request to the list of wait times. */
@@ -270,7 +240,6 @@ export function trackWaitTime(req: Request) {
partition: getModelFamilyForRequest(req),
start: req.startTime!,
end: req.queueOutTime ?? Date.now(),
isDeprioritized: isFromSharedIp(req),
});
}
@@ -296,8 +265,7 @@ function calculateWaitTime(partition: ModelFamily) {
.filter((wait) => {
const isSamePartition = wait.partition === partition;
const isRecent = now - wait.end < 300 * 1000;
const isNormalPriority = !wait.isDeprioritized;
return isSamePartition && isRecent && isNormalPriority;
return isSamePartition && isRecent;
})
.map((wait) => wait.end - wait.start);
const recentAverage = recentWaits.length
@@ -311,11 +279,7 @@ function calculateWaitTime(partition: ModelFamily) {
);
const currentWaits = queue
.filter((req) => {
const isSamePartition = getModelFamilyForRequest(req) === partition;
const isNormalPriority = !isFromSharedIp(req);
return isSamePartition && isNormalPriority;
})
.filter((req) => getModelFamilyForRequest(req) === partition)
.map((req) => now - req.startTime!);
const longestCurrentWait = Math.max(...currentWaits, 0);
+15 -32
View File
@@ -1,14 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { config } from "../config";
export const SHARED_IP_ADDRESSES = new Set([
// Agnai.chat
"157.230.249.32", // old
"157.245.148.56",
"174.138.29.50",
"209.97.162.44",
]);
const ONE_MINUTE_MS = 60 * 1000;
type Timestamp = number;
@@ -20,7 +12,10 @@ const exemptedRequests: Timestamp[] = [];
const isRecentAttempt = (now: Timestamp) => (attempt: Timestamp) =>
attempt > now - ONE_MINUTE_MS;
const getTryAgainInMs = (ip: string, type: "text" | "image") => {
/**
* Returns the duration in seconds to wait before retrying, for use in the Retry-After header.
*/
const getRetryAfter = (ip: string, type: "text" | "image") => {
const now = Date.now();
const attempts = lastAttempts.get(ip) || [];
const validAttempts = attempts.filter(isRecentAttempt(now));
@@ -29,7 +24,7 @@ const getTryAgainInMs = (ip: string, type: "text" | "image") => {
type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;
if (validAttempts.length >= limit) {
return validAttempts[0] - now + ONE_MINUTE_MS;
return (validAttempts[0] - now + ONE_MINUTE_MS) / 1000;
} else {
lastAttempts.set(ip, [...validAttempts, now]);
return 0;
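// Worked example: with a limit of 4 prompts per minute and valid attempts at
// t-50s, t-30s, t-10s, and t-5s, a new request gets
//   (t-50s) - t + 60s = 10s
// so the Retry-After header tells the client to wait 10 seconds, at which
// point the oldest attempt ages out of the window.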
@@ -96,22 +91,11 @@ export const ipLimiter = async (
if (!textLimit && !imageLimit) return next();
if (req.user?.type === "special") return next();
// Exempts Agnai.chat from IP-based rate limiting because its IPs are shared
// by many users. Instead, the request queue will limit the number of such
// requests that may wait in the queue at a time, and sorts them to the end to
// let individual users go first.
if (SHARED_IP_ADDRESSES.has(req.ip)) {
exemptedRequests.push(Date.now());
req.log.info(
{ ip: req.ip, recentExemptions: exemptedRequests.length },
"Exempting Agnai request from rate limiting."
);
return next();
}
const type = (req.baseUrl + req.path).includes("openai-image")
? "image"
: "text";
const path = req.baseUrl + req.path;
const type =
path.includes("openai-image") || path.includes("images/generations")
? "image"
: "text";
const limit = type === "image" ? imageLimit : textLimit;
// If user is authenticated, key rate limiting by their token. Otherwise, key
@@ -123,15 +107,14 @@ export const ipLimiter = async (
res.set("X-RateLimit-Remaining", remaining.toString());
res.set("X-RateLimit-Reset", reset.toString());
const tryAgainInMs = getTryAgainInMs(rateLimitKey, type);
if (tryAgainInMs > 0) {
res.set("Retry-After", tryAgainInMs.toString());
const retryAfterTime = getRetryAfter(rateLimitKey, type);
if (retryAfterTime > 0) {
const waitSec = Math.ceil(retryAfterTime).toString();
res.set("Retry-After", waitSec);
res.status(429).json({
error: {
type: "proxy_rate_limited",
message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${Math.ceil(
tryAgainInMs / 1000
)} seconds.`,
message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${waitSec} seconds.`,
},
});
} else {
+24 -20
View File
@@ -1,44 +1,55 @@
import express, { Request, Response, NextFunction } from "express";
import { gatekeeper } from "./gatekeeper";
import { checkRisuToken } from "./check-risu-token";
import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import express from "express";
import { addV1 } from "./add-v1";
import { anthropic } from "./anthropic";
import { aws } from "./aws";
import { azure } from "./azure";
import { checkRisuToken } from "./check-risu-token";
import { gatekeeper } from "./gatekeeper";
import { gcp } from "./gcp";
import { googleAI } from "./google-ai";
import { mistralAI } from "./mistral-ai";
import { aws } from "./aws";
import { gcp } from "./gcp";
import { azure } from "./azure";
import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import { sendErrorToClient } from "./middleware/response/error-generator";
const proxyRouter = express.Router();
// Remove `expect: 100-continue` header from requests due to incompatibility
// with node-http-proxy.
proxyRouter.use((req, _res, next) => {
if (req.headers.expect) {
// node-http-proxy does not like it when clients send `expect: 100-continue`
// and will stall. None of the upstream APIs use this header anyway.
delete req.headers.expect;
}
next();
});
// Apply body parsers.
proxyRouter.use(
express.json({ limit: "100mb" }),
express.urlencoded({ extended: true, limit: "100mb" })
);
// Apply auth/rate limits.
proxyRouter.use(gatekeeper);
proxyRouter.use(checkRisuToken);
// Initialize request queue metadata.
proxyRouter.use((req, _res, next) => {
req.startTime = Date.now();
req.retryCount = 0;
next();
});
// Proxy endpoints.
proxyRouter.use("/openai", addV1, openai);
proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-ai", addV1, googleAI);
proxyRouter.use("/mistral-ai", addV1, mistralAI);
proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/aws", aws);
proxyRouter.use("/gcp/claude", addV1, gcp);
proxyRouter.use("/azure/openai", addV1, azure);
// Redirect browser requests to the homepage.
proxyRouter.get("*", (req, res, next) => {
const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
@@ -48,7 +59,8 @@ proxyRouter.get("*", (req, res, next) => {
next();
}
});
// Handle 404s.
// Send a fake client error if user specifies an invalid proxy endpoint.
proxyRouter.use((req, res) => {
sendErrorToClient({
req,
@@ -69,11 +81,3 @@ proxyRouter.use((req, res) => {
});
export { proxyRouter as proxyRouter };
function addV1(req: Request, res: Response, next: NextFunction) {
// Clients don't consistently use the /v1 prefix so we'll add it for them.
if (!req.path.startsWith("/v1/") && !req.path.startsWith("/v1beta/")) {
req.url = `/v1${req.url}`;
}
next();
}
+1
View File
@@ -49,6 +49,7 @@ app.use(
// Don't log the prompt text on transform errors
"body.messages",
"body.prompt",
"body.contents",
],
censor: "********",
},
+92 -146
View File
@@ -3,8 +3,6 @@ import {
AnthropicKey,
AwsBedrockKey,
GcpKey,
AzureOpenAIKey,
GoogleAIKey,
keyPool,
OpenAIKey,
} from "./shared/key-management";
@@ -26,21 +24,14 @@ import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats";
import { getUniqueIps } from "./proxy/rate-limit";
import { assertNever } from "./shared/utils";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
import { MistralAIKey } from "./shared/key-management/mistral-ai/provider";
const CACHE_TTL = 2000;
type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
k.service === "openai";
const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
k.service === "azure";
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
k.service === "anthropic";
const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
k.service === "google-ai";
const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
k.service === "mistral-ai";
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
const keyIsGcpKey = (k: KeyPoolKey): k is GcpKey => k.service === "gcp";
@@ -54,14 +45,15 @@ type ModelAggregates = {
overQuota?: number;
pozzed?: number;
awsLogged?: number;
awsSonnet?: number;
awsSonnet35?: number;
awsHaiku?: number;
// needed to disambiguate aws-claude family's variants
awsClaude2?: number;
awsSonnet3?: number;
awsSonnet3_5?: number;
awsHaiku?: number;
gcpSonnet?: number;
gcpSonnet35?: number;
gcpHaiku?: number;
queued: number;
queueTime: string;
tokens: number;
};
/** All possible combinations of model family and aggregate type. */
@@ -93,14 +85,10 @@ type AnthropicInfo = BaseFamilyInfo & {
};
type AwsInfo = BaseFamilyInfo & {
privacy?: string;
sonnetKeys?: number;
sonnet35Keys?: number;
haikuKeys?: number;
enabledVariants?: string;
};
type GcpInfo = BaseFamilyInfo & {
sonnetKeys?: number;
sonnet35Keys?: number;
haikuKeys?: number;
enabledVariants?: string;
};
// prettier-ignore
@@ -108,12 +96,10 @@ export type ServiceInfo = {
uptime: number;
endpoints: {
openai?: string;
openai2?: string;
anthropic?: string;
"anthropic-claude-3"?: string;
"google-ai"?: string;
"mistral-ai"?: string;
aws?: string;
"aws"?: string;
gcp?: string;
azure?: string;
"openai-image"?: string;
@@ -151,7 +137,6 @@ export type ServiceInfo = {
const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
openai: {
openai: `%BASE%/openai`,
openai2: `%BASE%/openai/turbo-instruct`,
"openai-image": `%BASE%/openai-image`,
},
anthropic: {
@@ -164,7 +149,8 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
"mistral-ai": `%BASE%/mistral-ai`,
},
aws: {
aws: `%BASE%/aws/claude`,
"aws-claude": `%BASE%/aws/claude`,
"aws-mistral": `%BASE%/aws/mistral`,
},
gcp: {
gcp: `%BASE%/gcp/claude`,
@@ -175,7 +161,7 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
},
};
const modelStats = new Map<ModelAggregateKey, number>();
const familyStats = new Map<ModelAggregateKey, number>();
const serviceStats = new Map<keyof AllStats, number>();
let cachedInfo: ServiceInfo | undefined;
@@ -192,7 +178,7 @@ export function buildInfo(baseUrl: string, forAdmin = false): ServiceInfo {
.concat("turbo")
);
modelStats.clear();
familyStats.clear();
serviceStats.clear();
keys.forEach(addKeyToAggregates);
@@ -311,150 +297,102 @@ function increment<T extends keyof AllStats | ModelAggregateKey>(
) {
map.set(key, (map.get(key) || 0) + delta);
}
const addToService = increment.bind(null, serviceStats);
const addToFamily = increment.bind(null, familyStats);
function addKeyToAggregates(k: KeyPoolKey) {
increment(serviceStats, "proompts", k.promptCount);
increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
increment(
serviceStats,
"mistral-ai__keys",
k.service === "mistral-ai" ? 1 : 0
);
increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
increment(serviceStats, "gcp__keys", k.service === "gcp" ? 1 : 0);
increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
addToService("proompts", k.promptCount);
addToService("openai__keys", k.service === "openai" ? 1 : 0);
addToService("anthropic__keys", k.service === "anthropic" ? 1 : 0);
addToService("google-ai__keys", k.service === "google-ai" ? 1 : 0);
addToService("mistral-ai__keys", k.service === "mistral-ai" ? 1 : 0);
addToService("aws__keys", k.service === "aws" ? 1 : 0);
addToService("gcp__keys", k.service === "gcp" ? 1 : 0);
addToService("azure__keys", k.service === "azure" ? 1 : 0);
let sumTokens = 0;
let sumCost = 0;
const incrementGenericFamilyStats = (f: ModelFamily) => {
const tokens = (k as any)[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
addToFamily(`${f}__tokens`, tokens);
addToFamily(`${f}__revoked`, k.isRevoked ? 1 : 0);
addToFamily(`${f}__active`, k.isDisabled ? 0 : 1);
};
switch (k.service) {
case "openai":
if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
increment(
serviceStats,
"openai__uncheckedKeys",
Boolean(k.lastChecked) ? 0 : 1
);
addToService("openai__uncheckedKeys", Boolean(k.lastChecked) ? 0 : 1);
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
incrementGenericFamilyStats(f);
addToFamily(`${f}__trial`, k.isTrial ? 1 : 0);
addToFamily(`${f}__overQuota`, k.isOverQuota ? 1 : 0);
});
break;
case "azure":
if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
});
break;
case "anthropic": {
case "anthropic":
if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
addToService("anthropic__uncheckedKeys", Boolean(k.lastChecked) ? 0 : 1);
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__trial`, k.tier === "free" ? 1 : 0);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
increment(modelStats, `${f}__pozzed`, k.isPozzed ? 1 : 0);
});
increment(
serviceStats,
"anthropic__uncheckedKeys",
Boolean(k.lastChecked) ? 0 : 1
);
break;
}
case "google-ai": {
if (!keyIsGoogleAIKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((family) => {
const tokens = k[`${family}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(family, tokens);
increment(modelStats, `${family}__tokens`, tokens);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
incrementGenericFamilyStats(f);
addToFamily(`${f}__trial`, k.tier === "free" ? 1 : 0);
addToFamily(`${f}__overQuota`, k.isOverQuota ? 1 : 0);
addToFamily(`${f}__pozzed`, k.isPozzed ? 1 : 0);
});
break;
}
case "mistral-ai": {
if (!keyIsMistralAIKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
});
break;
}
case "aws": {
if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
});
increment(modelStats, `aws-claude__awsSonnet`, k.sonnetEnabled ? 1 : 0);
increment(modelStats, `aws-claude__awsSonnet35`, k.sonnet35Enabled ? 1 : 0);
increment(modelStats, `aws-claude__awsHaiku`, k.haikuEnabled ? 1 : 0);
k.modelFamilies.forEach(incrementGenericFamilyStats);
if (!k.isDisabled) {
// Don't count disabled (e.g. revoked) keys as available AWS variants
k.modelIds.forEach((id) => {
if (id.includes("claude-3-sonnet")) {
addToFamily(`aws-claude__awsSonnet3`, 1);
} else if (id.includes("claude-3-5-sonnet")) {
addToFamily(`aws-claude__awsSonnet3_5`, 1);
} else if (id.includes("claude-3-haiku")) {
addToFamily(`aws-claude__awsHaiku`, 1);
} else if (id.includes("claude-v2")) {
addToFamily(`aws-claude__awsClaude2`, 1);
}
});
}
// Ignore revoked keys for aws logging stats, but include keys where the
// logging status is unknown.
const countAsLogged =
k.lastChecked && !k.isDisabled && k.awsLoggingStatus === "enabled";
increment(modelStats, `aws-claude__awsLogged`, countAsLogged ? 1 : 0);
addToFamily(`aws-claude__awsLogged`, countAsLogged ? 1 : 0);
break;
}
case "gcp": {
case "gcp":
if (!keyIsGcpKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
});
increment(modelStats, `gcp-claude__gcpSonnet`, k.sonnetEnabled ? 1 : 0);
increment(modelStats, `gcp-claude__gcpSonnet35`, k.sonnet35Enabled ? 1 : 0);
increment(modelStats, `gcp-claude__gcpHaiku`, k.haikuEnabled ? 1 : 0);
k.modelFamilies.forEach(incrementGenericFamilyStats);
// TODO: add modelIds to GcpKey
break;
// These services don't have any additional stats to track.
case "azure":
case "google-ai":
case "mistral-ai":
k.modelFamilies.forEach(incrementGenericFamilyStats);
break;
}
default:
assertNever(k.service);
}
increment(serviceStats, "tokens", sumTokens);
increment(serviceStats, "tokenCost", sumCost);
addToService("tokens", sumTokens);
addToService("tokenCost", sumCost);
}
function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
const tokens = modelStats.get(`${family}__tokens`) || 0;
const tokens = familyStats.get(`${family}__tokens`) || 0;
const cost = getTokenCostUsd(family, tokens);
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = {
usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
activeKeys: modelStats.get(`${family}__active`) || 0,
revokedKeys: modelStats.get(`${family}__revoked`) || 0,
activeKeys: familyStats.get(`${family}__active`) || 0,
revokedKeys: familyStats.get(`${family}__revoked`) || 0,
};
// Add service-specific stats to the info object.
@@ -462,8 +400,8 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
const service = MODEL_FAMILY_SERVICE[family];
switch (service) {
case "openai":
info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0;
info.trialKeys = modelStats.get(`${family}__trial`) || 0;
info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0;
info.trialKeys = familyStats.get(`${family}__trial`) || 0;
// Delete trial/revoked keys for non-turbo families.
// Trials are turbo 99% of the time, and if a key is invalid we don't
@@ -474,16 +412,25 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
}
break;
case "anthropic":
info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0;
info.trialKeys = modelStats.get(`${family}__trial`) || 0;
info.prefilledKeys = modelStats.get(`${family}__pozzed`) || 0;
info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0;
info.trialKeys = familyStats.get(`${family}__trial`) || 0;
info.prefilledKeys = familyStats.get(`${family}__pozzed`) || 0;
break;
case "aws":
if (family === "aws-claude") {
info.sonnetKeys = modelStats.get(`${family}__awsSonnet`) || 0;
info.sonnet35Keys = modelStats.get(`${family}__awsSonnet35`) || 0;
info.haikuKeys = modelStats.get(`${family}__awsHaiku`) || 0;
const logged = modelStats.get(`${family}__awsLogged`) || 0;
const logged = familyStats.get(`${family}__awsLogged`) || 0;
const variants = new Set<string>();
if (familyStats.get(`${family}__awsClaude2`) || 0)
variants.add("claude2");
if (familyStats.get(`${family}__awsSonnet3`) || 0)
variants.add("sonnet3");
if (familyStats.get(`${family}__awsSonnet3_5`) || 0)
variants.add("sonnet3.5");
if (familyStats.get(`${family}__awsHaiku`) || 0)
variants.add("haiku");
info.enabledVariants = variants.size
? `${Array.from(variants).join(",")}`
: undefined;
if (logged > 0) {
info.privacy = config.allowAwsLogging
? `AWS logging verification inactive. Prompts could be logged.`
@@ -493,9 +440,8 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
break;
case "gcp":
if (family === "gcp-claude") {
info.sonnetKeys = modelStats.get(`${family}__gcpSonnet`) || 0;
info.sonnet35Keys = modelStats.get(`${family}__gcpSonnet35`) || 0;
info.haikuKeys = modelStats.get(`${family}__gcpHaiku`) || 0;
// TODO: implement
info.enabledVariants = "not implemented";
}
break;
}
+23 -1
View File
@@ -19,7 +19,12 @@ const AnthropicV1BaseSchema = z
top_k: z.coerce.number().optional(),
top_p: z.coerce.number().optional(),
metadata: z.object({ user_id: z.string().optional() }).optional(),
tools: z.array(z.any()).optional(),
tool_choice: z.any().optional(),
})
.omit(
Boolean(config.allowOpenAIToolUsage) ? {} : { tools: true, tool_choice: true }
)
.strip();
// https://docs.anthropic.com/claude/reference/complete_post [deprecated]
@@ -44,6 +49,18 @@ const AnthropicV1MessageMultimodalContentSchema = z.array(
data: z.string(),
}),
}),
z.object({
type: z.literal("tool_use"),
id: z.string(),
name: z.string(),
input: z.object({}).passthrough(),
}),
z.object({
type: z.literal("tool_result"),
tool_use_id: z.string(),
is_error: z.boolean().optional(),
content: z.union([z.string(), z.object({}).passthrough()]).optional(),
}),
])
);
@@ -63,7 +80,12 @@ export const AnthropicV1MessagesSchema = AnthropicV1BaseSchema.merge(
.number()
.int()
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
system: z.string().optional(),
system: z
.union([
z.string(),
z.array(z.object({ type: z.literal("text"), text: z.string() })),
])
.optional(),
})
);
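// Both system prompt forms now validate, e.g. (illustrative values):
//   { system: "You are a helpful assistant.", ... }
//   { system: [{ type: "text", text: "You are a helpful assistant." }], ... }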
export type AnthropicChatMessage = z.infer<
+1 -1
View File
@@ -31,7 +31,7 @@ export const GoogleAIV1GenerateContentSchema = z
topP: z.number().optional(),
topK: z.number().optional(),
stopSequences: z.array(z.string().max(500)).max(5).optional(),
}),
}).default({}),
})
.strip();
export type GoogleAIChatMessage = z.infer<
+7 -1
View File
@@ -21,7 +21,11 @@ import {
GoogleAIV1GenerateContentSchema,
transformOpenAIToGoogleAI,
} from "./google-ai";
import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
import {
MistralAIV1ChatCompletionsSchema,
MistralAIV1TextCompletionsSchema,
transformMistralChatToText,
} from "./mistral-ai";
export { OpenAIChatMessage } from "./openai";
export {
@@ -49,6 +53,7 @@ export const API_REQUEST_TRANSFORMERS: TransformerMap = {
"openai->openai-text": transformOpenAIToOpenAIText,
"openai->openai-image": transformOpenAIToOpenAIImage,
"openai->google-ai": transformOpenAIToGoogleAI,
"mistral-ai->mistral-text": transformMistralChatToText,
};
export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
@@ -59,4 +64,5 @@ export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
"openai-image": OpenAIV1ImagesGenerationSchema,
"google-ai": GoogleAIV1GenerateContentSchema,
"mistral-ai": MistralAIV1ChatCompletionsSchema,
"mistral-text": MistralAIV1TextCompletionsSchema,
};
+120 -14
View File
@@ -1,15 +1,34 @@
import { z } from "zod";
import { OPENAI_OUTPUT_MAX } from "./openai";
import { Template } from "@huggingface/jinja";
import { APIFormatTransformer } from "./index";
import { logger } from "../../logger";
const MistralChatMessageSchema = z.object({
role: z.enum(["system", "user", "assistant", "tool"]), // TODO: implement tools
content: z.string(),
prefix: z.boolean().optional(),
});
const MistralMessagesSchema = z.array(MistralChatMessageSchema).refine(
(input) => {
const prefixIdx = input.findIndex((msg) => Boolean(msg.prefix));
if (prefixIdx === -1) return true; // no prefix messages
const lastIdx = input.length - 1;
const lastMsg = input[lastIdx];
return prefixIdx === lastIdx && lastMsg.role === "assistant";
},
{
message:
"`prefix` can only be set to `true` on the last message, and only for an assistant message.",
}
);
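// Examples of the refinement above (illustrative messages):
//   [user, assistant(prefix: true)]        -> valid
//   [user, assistant(prefix: true), user]  -> rejected (prefix not last)
//   [user(prefix: true)]                   -> rejected (not an assistant message)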
// https://docs.mistral.ai/api#operation/createChatCompletion
export const MistralAIV1ChatCompletionsSchema = z.object({
const BaseMistralAIV1CompletionsSchema = z.object({
model: z.string(),
messages: z.array(
z.object({
role: z.enum(["system", "user", "assistant"]),
content: z.string(),
})
),
messages: MistralMessagesSchema.optional(),
prompt: z.string().optional(),
temperature: z.number().optional().default(0.7),
top_p: z.number().optional().default(1),
max_tokens: z.coerce
@@ -18,12 +37,50 @@ export const MistralAIV1ChatCompletionsSchema = z.object({
.nullish()
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
stream: z.boolean().optional().default(false),
// Mistral docs say that `stop` can be a string or array but AWS Mistral
// blows up if a string is passed. We must convert it to an array.
stop: z
.union([z.string(), z.array(z.string())])
.optional()
.default([])
.transform((v) => (Array.isArray(v) ? v : [v])),
random_seed: z.number().int().min(0).optional(),
response_format: z
.object({ type: z.enum(["text", "json_object"]) })
.optional(),
safe_prompt: z.boolean().optional().default(false),
random_seed: z.number().int().optional(),
});
export type MistralAIChatMessage = z.infer<
typeof MistralAIV1ChatCompletionsSchema
>["messages"][0];
export const MistralAIV1ChatCompletionsSchema =
BaseMistralAIV1CompletionsSchema.and(
z.object({ messages: MistralMessagesSchema })
);
export const MistralAIV1TextCompletionsSchema =
BaseMistralAIV1CompletionsSchema.and(z.object({ prompt: z.string() }));
/*
Stricter version that only allows a subset of the parameters. AWS Mistral
helpfully returns no error details if unsupported parameters are passed, so
this list comes from trial and error as of 2024-08-12.
*/
const BaseAWSMistralAIV1CompletionsSchema =
BaseMistralAIV1CompletionsSchema.pick({
temperature: true,
top_p: true,
max_tokens: true,
stop: true,
random_seed: true,
// response_format: true,
// safe_prompt: true,
}).strip();
export const AWSMistralV1ChatCompletionsSchema =
BaseAWSMistralAIV1CompletionsSchema.and(
z.object({ messages: MistralMessagesSchema })
);
export const AWSMistralV1TextCompletionsSchema =
BaseAWSMistralAIV1CompletionsSchema.and(z.object({ prompt: z.string() }));
export type MistralAIChatMessage = z.infer<typeof MistralChatMessageSchema>;
export function fixMistralPrompt(
messages: MistralAIChatMessage[]
@@ -31,12 +88,11 @@ export function fixMistralPrompt(
// Mistral uses OpenAI format but has some additional requirements:
// - Only one system message per request, and it must be the first message if
// present.
// - Final message must be a user message.
// - Final message must be a user message, unless it has `prefix: true`.
// - Cannot have multiple messages from the same role in a row.
// While frontends should be able to handle this, we can fix it here in the
// meantime.
return messages.reduce<MistralAIChatMessage[]>((acc, msg) => {
const fixed = messages.reduce<MistralAIChatMessage[]>((acc, msg) => {
if (acc.length === 0) {
acc.push(msg);
return acc;
@@ -57,4 +113,54 @@ export function fixMistralPrompt(
}
return acc;
}, []);
// If the last message is an assistant message, mark it as a prefix. An
// assistant message at the end of the conversation without `prefix: true`
// results in an error.
if (fixed.length > 0 && fixed[fixed.length - 1].role === "assistant") {
fixed[fixed.length - 1].prefix = true;
}
return fixed;
}
let jinjaTemplate: Template;
let renderTemplate: (messages: MistralAIChatMessage[]) => string;
function renderMistralPrompt(messages: MistralAIChatMessage[]) {
if (!jinjaTemplate) {
logger.warn("Lazy loading mistral chat template...");
const { chatTemplate, bosToken, eosToken } =
require("./templates/mistral-template").MISTRAL_TEMPLATE;
jinjaTemplate = new Template(chatTemplate);
renderTemplate = (messages) =>
jinjaTemplate.render({
messages,
bos_token: bosToken,
eos_token: eosToken,
});
}
return renderTemplate(messages);
}
/**
* Attempts to convert a Mistral chat completions request to a text completions,
* using the official prompt template published by Mistral.
*/
export const transformMistralChatToText: APIFormatTransformer<
typeof MistralAIV1TextCompletionsSchema
> = async (req) => {
const { body } = req;
const result = MistralAIV1ChatCompletionsSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid Mistral chat completions request"
);
throw result.error;
}
const { messages, ...rest } = result.data;
const prompt = renderMistralPrompt(messages);
return { ...rest, prompt, messages: undefined };
};
+1 -1
View File
@@ -52,7 +52,7 @@ export const OpenAIV1ChatCompletionSchema = z
.number()
.int()
.nullish()
.default(Math.min(OPENAI_OUTPUT_MAX, 4096))
.default(Math.min(OPENAI_OUTPUT_MAX, 16384))
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
frequency_penalty: z.number().optional().default(0),
presence_penalty: z.number().optional().default(0),
@@ -0,0 +1,36 @@
export const MISTRAL_TEMPLATE = {
bosToken: "<s>",
eosToken: "</s>",
chatTemplate: `{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
{%- for message in loop_messages %}
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif %}
{%- endfor %}
{{- bos_token }}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{%- if loop.last and system_message is defined %}
{{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}
{%- else %}
{{- "[INST] " + message["content"] + "[/INST]" }}
{%- endif %}
{%- elif message["role"] == "assistant" %}
{%- if loop.last and message.prefix is defined and message.prefix %}
{{- " " + message["content"] }}
{%- else %}
{{- " " + message["content"] + eos_token}}
{%- endif %}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}`,
};
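// Worked example, assuming @huggingface/jinja's Template API as used by
// renderMistralPrompt above: the system message is folded into the *last*
// user turn and assistant turns are closed with the EOS token.
const rendered = new Template(MISTRAL_TEMPLATE.chatTemplate).render({
  messages: [
    { role: "system", content: "Be terse." },
    { role: "user", content: "Hi!" },
    { role: "assistant", content: "Hello!" },
    { role: "user", content: "How are you?" },
  ],
  bos_token: MISTRAL_TEMPLATE.bosToken,
  eos_token: MISTRAL_TEMPLATE.eosToken,
});
// rendered === "<s>[INST] Hi![/INST] Hello!</s>[INST] Be terse.\n\nHow are you?[/INST]"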
+18
View File
@@ -0,0 +1,18 @@
/** Module for generating and verifying HMAC signatures. */
import crypto from "crypto";
import { SECRET_SIGNING_KEY } from "../config";
/**
 * Generates an HMAC signature for the given message. Optionally salts the
* key with a provided string.
*/
export function signMessage(msg: any, salt: string = ""): string {
const hmac = crypto.createHmac("sha256", SECRET_SIGNING_KEY + salt);
if (typeof msg === "object") {
hmac.update(JSON.stringify(msg));
} else {
hmac.update(msg);
}
return hmac.digest("hex");
}
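// Usage sketch: signatures are deterministic for a given key and salt, so a
// payload can be verified later by re-signing and comparing, which is how the
// proof-of-work routes further down consume this helper.
const sig = signMessage({ ip: "203.0.113.7", nonce: 42 }, "per-feature-salt");
const ok = signMessage({ ip: "203.0.113.7", nonce: 42 }, "per-feature-salt") === sig; // true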
+2 -2
View File
@@ -1,9 +1,9 @@
import { doubleCsrf } from "csrf-csrf";
import express from "express";
import { config, COOKIE_SECRET } from "../config";
import { config, SECRET_SIGNING_KEY } from "../config";
const { generateToken, doubleCsrfProtection } = doubleCsrf({
getSecret: () => COOKIE_SECRET,
getSecret: () => SECRET_SIGNING_KEY,
cookieName: "csrf",
cookieOptions: {
sameSite: "strict",
@@ -1,5 +1,5 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { AnthropicModelFamily, getClaudeModelFamily } from "../../models";
@@ -23,10 +23,6 @@ type AnthropicKeyUsage = {
export interface AnthropicKey extends Key, AnthropicKeyUsage {
readonly service: "anthropic";
readonly modelFamilies: AnthropicModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
/**
* Whether this key requires a special preamble. For unclear reasons, some
* Anthropic keys will throw an error if the prompt does not begin with a
@@ -217,22 +213,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
key[`${getClaudeModelFamily(model)}Tokens`] += tokens;
}
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return the time until the first key is
// ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
+214 -48
View File
@@ -1,13 +1,31 @@
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios";
import axios, { AxiosError, AxiosHeaders, AxiosRequestConfig } from "axios";
import { URL } from "url";
import { config } from "../../../config";
import { getAwsBedrockModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
import { AwsBedrockModelFamily } from "../../models";
import { config } from "../../../config";
type ParentModelId = string;
type AliasModelId = string;
type ModelAliasTuple = [ParentModelId, ...AliasModelId[]];
const KNOWN_MODEL_IDS: ModelAliasTuple[] = [
["anthropic.claude-v2", "anthropic.claude-v2:1"],
["anthropic.claude-3-sonnet-20240229-v1:0"],
["anthropic.claude-3-haiku-20240307-v1:0"],
["anthropic.claude-3-opus-20240229-v1:0"],
["anthropic.claude-3-5-sonnet-20240620-v1:0"],
["mistral.mistral-7b-instruct-v0:2"],
["mistral.mixtral-8x7b-instruct-v0:1"],
["mistral.mistral-large-2402-v1:0"],
["mistral.mistral-large-2407-v1:0"],
["mistral.mistral-small-2402-v1:0"], // Seems to return 400
];
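// Aliases ride along on their parent's check: if "anthropic.claude-v2" passes
// the invocation test below, "anthropic.claude-v2:1" is enabled as well.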
const KEY_CHECK_BATCH_SIZE = 2; // the AWS checker fans out many concurrent requests per key, so use a smaller batch
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
const AMZ_HOST =
@@ -15,6 +33,8 @@ const AMZ_HOST =
const GET_CALLER_IDENTITY_URL = `https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15`;
const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) =>
`https://bedrock.${region}.amazonaws.com/logging/modelinvocations`;
const GET_LIST_INFERENCE_PROFILES_URL = (region: string) =>
`https://bedrock.${region}.amazonaws.com/inference-profiles?maxResults=1000`;
const POST_INVOKE_MODEL_URL = (region: string, model: string) =>
`https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`;
const TEST_MESSAGES = [
@@ -24,6 +44,22 @@ const TEST_MESSAGES = [
type AwsError = { error: {} };
type GetInferenceProfilesResponse = {
inferenceProfileSummaries: {
inferenceProfileId: string;
inferenceProfileName: string;
inferenceProfileArn: string;
description?: string;
createdAt?: string;
updatedAt?: string;
status: "ACTIVE" | unknown;
type: "SYSTEM_DEFINED" | unknown;
models: {
modelArn?: string;
}[];
}[];
};
type GetLoggingConfigResponse = {
loggingConfig: null | {
cloudWatchConfig: null | unknown;
@@ -42,63 +78,67 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
service: "aws",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
keyCheckBatchSize: KEY_CHECK_BATCH_SIZE,
updateKey,
});
}
protected async testKeyOrFail(key: AwsBedrockKey) {
// Only check models on startup. For now all models must be available to
// the proxy because we don't route requests to different keys.
let checks: Promise<boolean>[] = [];
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
checks = [
this.invokeModel("anthropic.claude-v2", key),
this.invokeModel("anthropic.claude-3-sonnet-20240229-v1:0", key),
this.invokeModel("anthropic.claude-3-haiku-20240307-v1:0", key),
this.invokeModel("anthropic.claude-3-opus-20240229-v1:0", key),
this.invokeModel("anthropic.claude-3-5-sonnet-20240620-v1:0", key),
];
try {
await this.checkInferenceProfiles(key);
} catch (e) {
const asError = e as AxiosError<AwsError>;
const data = asError.response?.data;
this.log.warn(
{ key: key.hash, error: e.message, data },
"Cannot list inference profiles.\n\
Principal may be missing `AmazonBedrockFullAccess`, or has no policy allowing action `bedrock:ListInferenceProfiles` against resource `arn:aws:bedrock:*:*:inference-profile/*`.\n\
Requests will be made without inference profiles using on-demand quotas, which may be subject to more restrictive rate limits.\n\
See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-prereq.html."
);
}
}
checks.unshift(this.checkLoggingConfiguration(key));
const [_logging, claudeV2, sonnet, haiku, opus, sonnet35] =
await Promise.all(checks);
this.log.debug(
{ key: key.hash, _logging, claudeV2, sonnet, haiku, opus, sonnet35 },
"AWS model tests complete."
// Perform checks for all parent model IDs
const results = await Promise.all(
KNOWN_MODEL_IDS.filter(([model]) =>
// Skip checks for models that are disabled anyway
config.allowedModelFamilies.includes(getAwsBedrockModelFamily(model))
).map(async ([model, ...aliases]) => ({
models: [model, ...aliases],
success: await this.invokeModel(model, key),
}))
);
if (isInitialCheck) {
const families: AwsBedrockModelFamily[] = [];
if (claudeV2 || sonnet || sonnet35 || haiku) families.push("aws-claude");
if (opus) families.push("aws-claude-opus");
// Keep only the models (and their aliases) whose invocation test succeeded
const modelIds = results
.filter(({ success }) => success)
.flatMap(({ models }) => models);
if (families.length === 0) {
this.log.warn(
{ key: key.hash },
"Key does not have access to any models; disabling."
);
return this.updateKey(key.hash, { isDisabled: true });
}
this.updateKey(key.hash, {
sonnetEnabled: sonnet,
haikuEnabled: haiku,
sonnet35Enabled: sonnet35,
modelFamilies: families,
});
if (modelIds.length === 0) {
this.log.warn(
{ key: key.hash },
"Key does not have access to any models; disabling."
);
return this.updateKey(key.hash, { isDisabled: true });
}
this.updateKey(key.hash, {
modelIds,
modelFamilies: Array.from(
new Set(modelIds.map(getAwsBedrockModelFamily))
),
});
this.log.info(
{
key: key.hash,
sonnet,
haiku,
families: key.modelFamilies,
logged: key.awsLoggingStatus,
families: key.modelFamilies,
models: key.modelIds,
},
"Checked key."
);
@@ -169,7 +209,52 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
* key has access to the model, false if it does not. Throws an error if the
* key is disabled.
*/
private async invokeModel(model: string, key: AwsBedrockKey) {
private async invokeModel(
model: string,
key: AwsBedrockKey
): Promise<boolean> {
if (model.includes("claude")) {
// If inference profiles are available, try testing model with them.
// If they are not available or the invocation fails with the inference
// profile, fall back to regular model ID.
const { region } = AwsKeyChecker.getCredentialsFromKey(key);
const continent = region.split("-")[0];
const profile = key.inferenceProfileIds.find(
(id) => `${continent}.${model}` === id
);
if (profile) {
this.log.debug(
{ key: key.hash, model, profile },
"Testing model via inference profile."
);
let result: boolean;
try {
result = await this.testClaudeModel(key, profile);
} catch (e) {
this.log.error(
{ key: key.hash, model, profile, error: e.message },
"Error testing model with inference profile; trying model ID directly."
);
result = false;
}
// If the profile worked, we'll return success. Caller will add the
// model (not the profile) to the list of enabled models, but the
// profile will be used when the key is used for inference.
if (result) return true;
}
return this.testClaudeModel(key, model);
} else if (model.includes("mistral")) {
return this.testMistralModel(key, model);
}
throw new Error("AwsKeyChecker#invokeModel: no implementation for model");
}
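// Illustrative values for the profile lookup above: a key whose region is
// "us-east-1" has continent prefix "us", so a SYSTEM_DEFINED profile id such
// as "us.anthropic.claude-3-5-sonnet-20240620-v1:0" matches
// `${continent}.${model}` and is tried before the bare model id.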
private async testClaudeModel(
key: AwsBedrockKey,
model: string
): Promise<boolean> {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
// This is not a valid invocation payload, but a 400 response indicates that
// the principal at least has permission to invoke the model.
@@ -196,20 +281,25 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
const errorMessage = data?.message;
// We only allow one type of 403 error, and we only allow it for one model.
// This message indicates the key is valid but this particular model is not
// accessible. Other 403s may indicate the key is not usable.
if (
status === 403 &&
errorMessage?.match(/access to the model with the specified model ID/)
) {
this.log.debug(
{ key: key.hash, model, errorType, data, status, headers },
"Model is not available (principal does not have access)."
);
return false;
}
// ResourceNotFound typically indicates that the tested model cannot be used
// on the configured region for this set of credentials.
if (status === 404) {
this.log.debug(
{ region: creds.region, model, key: key.hash },
"Model not supported in this AWS region."
"Model is not available (not supported in this AWS region)."
);
return false;
}
@@ -219,16 +309,91 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
const correctErrorType = errorType === "ValidationException";
const correctErrorMessage = errorMessage?.match(/max_tokens/);
if (!correctErrorType || !correctErrorMessage) {
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"Model is not available (request rejected)."
);
return false;
}
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"AWS InvokeModel test successful."
"Model is available."
);
return true;
}
private async testMistralModel(
key: AwsBedrockKey,
model: string
): Promise<boolean> {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
const payload = {
max_tokens: -1,
prompt: "<s>[INST] What is your favourite condiment? [/INST]</s>",
};
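// As with the Claude probe above, this payload is intentionally invalid:
// max_tokens of -1 should draw a 400 validation error from any model the
// principal can invoke, while 403/404 responses indicate no access.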
const config: AxiosRequestConfig = {
method: "POST",
url: POST_INVOKE_MODEL_URL(creds.region, model),
data: payload,
validateStatus: (status) => [400, 403, 404].includes(status),
headers: {
"content-type": "application/json",
accept: "*/*",
},
};
await AwsKeyChecker.signRequestForAws(config, key);
const response = await axios.request(config);
const { data, status, headers } = response;
const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
const errorMessage = data?.message;
if (status === 403 || status === 404) {
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"Model is not available (no access or unsupported region)."
);
return false;
}
const isBadRequest = status === 400;
const isValidationError = errorMessage?.match(/validation error/i);
if (isBadRequest && !isValidationError) {
this.log.debug(
{ key: key.hash, model, errorType, data, status, headers },
"Model is not available (request rejected)."
);
return false;
}
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"Model is available."
);
return true;
}
private async checkInferenceProfiles(key: AwsBedrockKey) {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
const req: AxiosRequestConfig = {
method: "GET",
url: GET_LIST_INFERENCE_PROFILES_URL(creds.region),
headers: { accept: "application/json" },
};
await AwsKeyChecker.signRequestForAws(req, key);
const { data } = await axios.request<GetInferenceProfilesResponse>(req);
const { inferenceProfileSummaries } = data;
const profileIds = inferenceProfileSummaries.map(
(p) => p.inferenceProfileId
);
this.log.debug(
{ key: key.hash, profileIds, region: creds.region },
"Inference profiles found."
);
this.updateKey(key.hash, { inferenceProfileIds: profileIds });
}
private async checkLoggingConfiguration(key: AwsBedrockKey) {
if (config.allowAwsLogging) {
// Don't check logging status if we're allowing it to reduce API calls.
@@ -297,7 +462,8 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
method,
protocol: "https:",
hostname: url.hostname,
path: url.pathname + url.search,
path: url.pathname,
query: Object.fromEntries(url.searchParams),
headers: { Host: url.hostname, ...plainHeaders },
});
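// Sketch of why the path/query split matters, assuming @smithy/signature-v4
// semantics: the signer builds its canonical query string from
// HttpRequest.query, so a "?maxResults=1000" left inside `path` (as before)
// canonicalizes differently than what AWS computes server-side and the
// signature check fails.
const parsed = new URL(GET_LIST_INFERENCE_PROFILES_URL("us-east-1"));
Object.fromEntries(parsed.searchParams); // { maxResults: "1000" }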
+46 -76
View File
@@ -1,10 +1,11 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models";
import { AwsKeyChecker } from "./checker";
import { PaymentRequiredError } from "../../errors";
import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { AwsKeyChecker } from "./checker";
type AwsBedrockKeyUsage = {
[K in AwsBedrockModelFamily as `${K}Tokens`]: number;
@@ -13,10 +14,6 @@ type AwsBedrockKeyUsage = {
export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
readonly service: "aws";
readonly modelFamilies: AwsBedrockModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
/**
* The confirmed logging status of this key. This is "unknown" until we
* receive a response from the AWS API. Keys which are logged, or not
@@ -24,9 +21,8 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
* set.
*/
awsLoggingStatus: "unknown" | "disabled" | "enabled";
sonnetEnabled: boolean;
haikuEnabled: boolean;
sonnet35Enabled: boolean;
modelIds: string[];
inferenceProfileIds: string[];
}
/**
@@ -76,11 +72,14 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
sonnetEnabled: true,
haikuEnabled: false,
sonnet35Enabled: false,
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
inferenceProfileIds: [],
["aws-claudeTokens"]: 0,
["aws-claude-opusTokens"]: 0,
["aws-mistral-tinyTokens"]: 0,
["aws-mistral-smallTokens"]: 0,
["aws-mistral-mediumTokens"]: 0,
["aws-mistral-largeTokens"]: 0,
};
this.keys.push(newKey);
}
@@ -99,41 +98,35 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
}
public get(model: string) {
let neededVariantId = model;
// This function accepts both Anthropic/Mistral IDs and AWS IDs.
// Generally all AWS model IDs are supersets of the original vendor IDs.
// Claude 2 is the only model that breaks this convention; Anthropic calls
// it claude-2 but AWS calls it claude-v2.
if (model.includes("claude-2")) neededVariantId = "claude-v2";
const neededFamily = getAwsBedrockModelFamily(model);
// this is a horrible mess
// each of these should be separate model families, but adding model
// families is not low enough friction for the rate at which aws claude
// model variants are added.
const needsSonnet35 =
model.includes("claude-3-5-sonnet") && neededFamily === "aws-claude";
const needsSonnet =
!needsSonnet35 &&
model.includes("sonnet") &&
neededFamily === "aws-claude";
const needsHaiku = model.includes("haiku") && neededFamily === "aws-claude";
const availableKeys = this.keys.filter((k) => {
const isNotLogged = k.awsLoggingStatus !== "enabled";
// Select keys which
return (
// are enabled
!k.isDisabled &&
(isNotLogged || config.allowAwsLogging) &&
(k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under aws-claude, while opus is not
(k.haikuEnabled || !needsHaiku) &&
(k.sonnet35Enabled || !needsSonnet35) &&
k.modelFamilies.includes(neededFamily)
// are not logged, unless policy allows it
(config.allowAwsLogging || k.awsLoggingStatus !== "enabled") &&
// have access to the model family we need
k.modelFamilies.includes(neededFamily) &&
// have access to the specific variant we need
k.modelIds.some((m) => m.includes(neededVariantId))
);
});
this.log.debug(
{
model,
neededFamily,
needsSonnet,
needsHaiku,
needsSonnet35,
availableKeys: availableKeys.length,
requestedModel: model,
selectedVariant: neededVariantId,
selectedFamily: neededFamily,
totalKeys: this.keys.length,
availableKeys: availableKeys.length,
},
"Selecting AWS key"
);
@@ -144,30 +137,22 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
);
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 3. Keys which have not been used in the longest time
/**
* Comparator for prioritizing keys on inference profile compatibility.
* Requests made via inference profiles have higher rate limits so we want
* to use keys with compatible inference profiles first.
*/
const hasInferenceProfile = (
a: AwsBedrockKey,
b: AwsBedrockKey
) => {
const aMatch = +a.inferenceProfileIds.some((p) => p.includes(model));
const bMatch = +b.inferenceProfileIds.some((p) => p.includes(model));
return aMatch - bMatch;
};
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
const selectedKey = prioritizeKeys(availableKeys, hasInferenceProfile)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@@ -195,22 +180,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens;
}
public getLockoutPeriod() {
// TODO: same exact behavior for three providers, should be refactored
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return time until the first key is ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
+10 -52
View File
@@ -1,10 +1,13 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { PaymentRequiredError } from "../../errors";
import { logger } from "../../../logger";
import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models";
import { PaymentRequiredError } from "../../errors";
import {
AzureOpenAIModelFamily,
getAzureOpenAIModelFamily,
} from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { AzureOpenAIKeyChecker } from "./checker";
type AzureOpenAIKeyUsage = {
@@ -14,10 +17,6 @@ type AzureOpenAIKeyUsage = {
export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
readonly service: "azure";
readonly modelFamilies: AzureOpenAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
contentFiltering: boolean;
}
@@ -105,30 +104,8 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
);
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 3. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
const selectedKey = prioritizeKeys(availableKeys)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@@ -156,26 +133,7 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
}
// TODO: all of this shit is duplicate code
public getLockoutPeriod(family: AzureOpenAIModelFamily) {
const activeKeys = this.keys.filter(
(key) => !key.isDisabled && key.modelFamilies.includes(family)
);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return time until the first key is ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
+51 -34
View File
@@ -6,10 +6,12 @@ import { GcpModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
const GCP_HOST =
process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
const POST_STREAM_RAW_URL = (project: string, region: string, model: string) =>
`https://${GCP_HOST.replace("%REGION%", region)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
`https://${GCP_HOST.replace(
"%REGION%",
region
)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
const TEST_MESSAGES = [
{ role: "user", content: "Hi!" },
{ role: "assistant", content: "Hello!" },
@@ -23,6 +25,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
service: "gcp",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
}
@@ -38,9 +41,8 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
this.invokeModel("claude-3-5-sonnet@20240620", key, true),
];
const [sonnet, haiku, opus, sonnet35] =
await Promise.all(checks);
const [sonnet, haiku, opus, sonnet35] = await Promise.all(checks);
this.log.debug(
{ key: key.hash, sonnet, haiku, opus, sonnet35 },
"GCP model initial tests complete."
@@ -66,20 +68,17 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
});
} else {
if (key.haikuEnabled) {
await this.invokeModel("claude-3-haiku@20240307", key, false)
await this.invokeModel("claude-3-haiku@20240307", key, false);
} else if (key.sonnetEnabled) {
await this.invokeModel("claude-3-sonnet@20240229", key, false)
await this.invokeModel("claude-3-sonnet@20240229", key, false);
} else if (key.sonnet35Enabled) {
await this.invokeModel("claude-3-5-sonnet@20240620", key, false)
await this.invokeModel("claude-3-5-sonnet@20240620", key, false);
} else {
await this.invokeModel("claude-3-opus@20240229", key, false)
await this.invokeModel("claude-3-opus@20240229", key, false);
}
this.updateKey(key.hash, { lastChecked: Date.now() });
this.log.debug(
{ key: key.hash},
"GCP key check complete."
);
this.log.debug({ key: key.hash }, "GCP key check complete.");
}
this.log.info(
@@ -134,8 +133,12 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
*/
private async invokeModel(model: string, key: GcpKey, initial: boolean) {
const creds = GcpKeyChecker.getCredentialsFromKey(key);
const signedJWT = await GcpKeyChecker.createSignedJWT(creds.clientEmail, creds.privateKey)
const [accessToken, jwtError] = await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT)
const signedJWT = await GcpKeyChecker.createSignedJWT(
creds.clientEmail,
creds.privateKey
);
const [accessToken, jwtError] =
await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT);
if (accessToken === null) {
this.log.warn(
{ key: key.hash, jwtError },
@@ -151,15 +154,19 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
const { data, status } = await axios.post(
POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
payload,
{
{
headers: GcpKeyChecker.getRequestHeaders(accessToken),
validateStatus: initial ? () => true : (status: number) => status >= 200 && status < 300
validateStatus: initial
? () => true
: (status: number) => status >= 200 && status < 300,
}
);
this.log.debug({ key: key.hash, data }, "Response from GCP");
if (initial) {
return (status >= 200 && status < 300) || (status === 429 || status === 529);
return (
(status >= 200 && status < 300) || status === 429 || status === 529
);
}
return true;
@@ -178,10 +185,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
let cryptoKey = await crypto.subtle.importKey(
"pkcs8",
GcpKeyChecker.str2ab(atob(pkey)),
{
name: "RSASSA-PKCS1-v1_5",
hash: { name: "SHA-256" },
},
{ name: "RSASSA-PKCS1-v1_5", hash: { name: "SHA-256" } },
false,
["sign"]
);
@@ -190,10 +194,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
const issued = Math.floor(Date.now() / 1000);
const expires = issued + 600;
const header = {
alg: "RS256",
typ: "JWT",
};
const header = { alg: "RS256", typ: "JWT" };
const payload = {
iss: email,
@@ -203,8 +204,12 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
scope: "https://www.googleapis.com/auth/cloud-platform",
};
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(header));
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(payload));
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(
JSON.stringify(header)
);
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(
JSON.stringify(payload)
);
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
@@ -218,7 +223,9 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
return `${unsignedToken}.${encodedSignature}`;
}
static async exchangeJwtForAccessToken(signed_jwt: string): Promise<[string | null, string]> {
static async exchangeJwtForAccessToken(
signed_jwt: string
): Promise<[string | null, string]> {
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
const params = {
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
@@ -252,7 +259,11 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
static urlSafeBase64Encode(data: string | ArrayBuffer): string {
let base64: string;
if (typeof data === "string") {
base64 = btoa(encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => String.fromCharCode(parseInt("0x" + p1, 16))));
base64 = btoa(
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
String.fromCharCode(parseInt("0x" + p1, 16))
)
);
} else {
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
}
@@ -260,7 +271,10 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
}
static getRequestHeaders(accessToken: string) {
return { "Authorization": `Bearer ${accessToken}`, "Content-Type": "application/json" };
return {
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
};
}
static getCredentialsFromKey(key: GcpKey) {
@@ -269,9 +283,12 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
throw new Error("Invalid GCP key");
}
const privateKey = rawPrivateKey
.replace(/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, '')
.replace(
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
""
)
.trim();
return { projectId, clientEmail, region, privateKey };
}
}
+7 -47
View File
@@ -1,10 +1,11 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { GcpModelFamily, getGcpModelFamily } from "../../models";
import { GcpKeyChecker } from "./checker";
import { PaymentRequiredError } from "../../errors";
import { GcpModelFamily, getGcpModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { GcpKeyChecker } from "./checker";
type GcpKeyUsage = {
[K in GcpModelFamily as `${K}Tokens`]: number;
@@ -13,10 +14,6 @@ type GcpKeyUsage = {
export interface GcpKey extends Key, GcpKeyUsage {
readonly service: "gcp";
readonly modelFamilies: GcpModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
sonnetEnabled: boolean;
haikuEnabled: boolean;
sonnet35Enabled: boolean;
@@ -134,30 +131,8 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
);
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 3. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
const selectedKey = prioritizeKeys(availableKeys)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@@ -185,22 +160,7 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
key[`${getGcpModelFamily(model)}Tokens`] += tokens;
}
public getLockoutPeriod() {
// TODO: same exact behavior for three providers, should be refactored
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return time until the first key is ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
@@ -1,9 +1,10 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { getGoogleAIModelFamily, type GoogleAIModelFamily } from "../../models";
import { PaymentRequiredError } from "../../errors";
import { getGoogleAIModelFamily, type GoogleAIModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { GoogleAIKeyChecker } from "./checker";
// Note that Google AI is not the same as Vertex AI, both are provided by
@@ -28,10 +29,6 @@ type GoogleAIKeyUsage = {
export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
readonly service: "google-ai";
readonly modelFamilies: GoogleAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
/** All detected model IDs on this key. */
modelIds: string[];
}
@@ -112,29 +109,10 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
throw new PaymentRequiredError("No Google AI keys available");
}
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 3. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const keysByPriority = prioritizeKeys(availableKeys);
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@@ -162,22 +140,7 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
key[`${getGoogleAIModelFamily(model)}Tokens`] += tokens;
}
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return the time until the first key is
// ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
+30 -3
View File
@@ -9,7 +9,8 @@ export type APIFormat =
| "anthropic-chat" // Anthropic's newer messages array format
| "anthropic-text" // Legacy flat string prompt format
| "google-ai"
| "mistral-ai";
| "mistral-ai"
| "mistral-text"
export interface Key {
/** The API key itself. Never log this, use `hash` instead. */
@@ -30,6 +31,10 @@ export interface Key {
lastChecked: number;
/** Hash of the key, for logging and to find the key in the pool. */
hash: string;
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
}
/*
@@ -58,10 +63,32 @@ export interface KeyProvider<T extends Key = Key> {
recheck(): void;
}
export function createGenericGetLockoutPeriod<T extends Key>(
getKeys: () => T[]
) {
return function (this: unknown, family?: ModelFamily): number {
const keys = getKeys();
const activeKeys = keys.filter(
(k) => !k.isDisabled && (!family || k.modelFamilies.includes(family))
);
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
};
}
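// Worked example: with two active keys rate limited until now+3000ms and
// now+8000ms and no key free, the returned function reports a 3000ms lockout;
// if any active key is not rate limited it reports 0 so the queue keeps
// moving. `pool` here is a hypothetical provider holding such keys.
const getLockout = createGenericGetLockoutPeriod(() => pool.keys);
getLockout("aws-claude"); // -> 3000 when all keys are limited, 0 otherwise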
export const keyPool = new KeyPool();
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GoogleAIKey } from "./google-ai/provider";
export { AwsBedrockKey } from "./aws/provider";
export { GcpKey } from "./gcp/provider";
export { AzureOpenAIKey } from "./azure/provider";
export { GoogleAIKey } from "./google-ai/provider";
export { MistralAIKey } from "./mistral-ai/provider";
export { OpenAIKey } from "./openai/provider";
@@ -7,6 +7,7 @@ type KeyCheckerOptions<TKey extends Key = Key> = {
service: string;
keyCheckPeriod: number;
minCheckInterval: number;
keyCheckBatchSize?: number;
recurringChecksEnabled?: boolean;
updateKey: (hash: string, props: Partial<TKey>) => void;
};
@@ -22,6 +23,8 @@ export abstract class KeyCheckerBase<TKey extends Key> {
* than this.
*/
protected readonly keyCheckPeriod: number;
/** Maximum number of keys to check simultaneously. */
protected readonly keyCheckBatchSize: number;
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
protected readonly keys: TKey[] = [];
protected log: pino.Logger;
@@ -33,6 +36,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
this.keyCheckPeriod = opts.keyCheckPeriod;
this.minCheckInterval = opts.minCheckInterval;
this.recurringChecksEnabled = opts.recurringChecksEnabled ?? true;
this.keyCheckBatchSize = opts.keyCheckBatchSize ?? 12;
this.updateKey = opts.updateKey;
this.service = opts.service;
this.log = logger.child({ module: "key-checker", service: opts.service });
@@ -78,7 +82,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
if (numUnchecked > 0) {
const keycheckBatch = uncheckedKeys.slice(0, 12);
const keycheckBatch = uncheckedKeys.slice(0, this.keyCheckBatchSize);
this.timeout = setTimeout(async () => {
try {
@@ -1,10 +1,11 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { MistralAIKeyChecker } from "./checker";
import { HttpError } from "../../errors";
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { MistralAIKeyChecker } from "./checker";
type MistralAIKeyUsage = {
[K in MistralAIModelFamily as `${K}Tokens`]: number;
@@ -13,10 +14,6 @@ type MistralAIKeyUsage = {
export interface MistralAIKey extends Key, MistralAIKeyUsage {
readonly service: "mistral-ai";
readonly modelFamilies: MistralAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
}
/**
@@ -98,30 +95,8 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
throw new HttpError(402, "No Mistral AI keys available");
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 3. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
const selectedKey = prioritizeKeys(availableKeys)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@@ -150,22 +125,7 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
key[`${family}Tokens`] += tokens;
}
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return the time until the first key is
// ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
+1 -2
View File
@@ -26,8 +26,6 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage {
isTrial: boolean;
/** Set when key check returns a non-transient 429. */
isOverQuota: boolean;
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/**
* Last known X-RateLimit-Requests-Reset header from OpenAI, converted to a
* number.
@@ -111,6 +109,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
.digest("hex")
.slice(0, 8)}`,
rateLimitedAt: 0,
rateLimitedUntil: 0,
rateLimitRequestsReset: 0,
rateLimitTokensReset: 0,
turboTokens: 0,
@@ -0,0 +1,39 @@
import { Key } from "./index";
/**
* Given a list of keys, returns a new list of keys sorted from highest to
* lowest priority. Keys are prioritized in the following order:
*
* 1. Keys which are not rate limited
* a. If all keys were rate limited recently, select the least-recently
* rate limited key.
* b. Otherwise, select the first key.
 * 2. Keys according to the custom comparator, if provided
 * 3. Keys which have not been used in the longest time
* @param keys The list of keys to sort
* @param customComparator A custom comparator function to use for sorting
*/
export function prioritizeKeys<T extends Key>(
keys: T[],
customComparator?: (a: T, b: T) => number
) {
const now = Date.now();
return keys.sort((a, b) => {
// A key counts as rate limited while the current time is still before its
// rateLimitedUntil timestamp.
const aRateLimited = now < a.rateLimitedUntil;
const bRateLimited = now < b.rateLimitedUntil;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
if (customComparator) {
const result = customComparator(a, b);
if (result !== 0) return result;
}
return a.lastUsed - b.lastUsed;
});
}
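// Usage sketch with illustrative keys (`baseKey` is a placeholder for the
// remaining Key fields): a key still inside its rate-limit window sorts after
// a free key; ties fall back to the custom comparator, then to
// least-recently-used.
const now = Date.now();
const [first] = prioritizeKeys([
  { ...baseKey, hash: "a", rateLimitedUntil: now + 5000, lastUsed: now - 100 },
  { ...baseKey, hash: "b", rateLimitedUntil: 0, lastUsed: now - 50 },
]);
// first.hash === "b"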
+26 -5
View File
@@ -32,7 +32,9 @@ export type MistralAIModelFamily =
// mistral changes their model classes frequently so these no longer
// correspond to specific models. consider them rough pricing tiers.
"mistral-tiny" | "mistral-small" | "mistral-medium" | "mistral-large";
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
export type AwsBedrockModelFamily = `aws-${
| AnthropicModelFamily
| MistralAIModelFamily}`;
export type GcpModelFamily = "gcp-claude" | "gcp-claude-opus";
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
export type ModelFamily =
@@ -64,6 +66,10 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"mistral-large",
"aws-claude",
"aws-claude-opus",
"aws-mistral-tiny",
"aws-mistral-small",
"aws-mistral-medium",
"aws-mistral-large",
"gcp-claude",
"gcp-claude-opus",
"azure-turbo",
@@ -99,6 +105,10 @@ export const MODEL_FAMILY_SERVICE: {
"claude-opus": "anthropic",
"aws-claude": "aws",
"aws-claude-opus": "aws",
"aws-mistral-tiny": "aws",
"aws-mistral-small": "aws",
"aws-mistral-medium": "aws",
"aws-mistral-large": "aws",
"gcp-claude": "gcp",
"gcp-claude-opus": "gcp",
"azure-turbo": "azure",
@@ -120,6 +130,7 @@ export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
"^gpt-4o(-\\d{4}-\\d{2}-\\d{2})?$": "gpt4o",
"^chatgpt-4o": "gpt4o",
"^gpt-4o-mini(-\\d{4}-\\d{2}-\\d{2})?$": "turbo", // closest match
"^gpt-4-turbo(-\\d{4}-\\d{2}-\\d{2})?$": "gpt4-turbo",
"^gpt-4-turbo(-preview)?$": "gpt4-turbo",
@@ -180,8 +191,16 @@ export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
}
export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily {
if (model.includes("opus")) return "aws-claude-opus";
return "aws-claude";
// remove vendor and version from AWS model ids
// 'anthropic.claude-3-5-sonnet-20240620-v1:0' -> 'claude-3-5-sonnet-20240620'
const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d+)?(:\d+)*$/, "$2");
if (["claude", "anthropic"].some((x) => model.includes(x))) {
return `aws-${getClaudeModelFamily(deAwsified)}`;
} else if (model.includes("tral")) {
return `aws-${getMistralAIModelFamily(deAwsified)}`;
}
return `aws-claude`;
}
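// Quick check of the de-AWS-ification above (illustrative, not exhaustive):
const deAwsify = (id: string) => id.replace(/^(\w+)\.(.+?)(-v\d+)?(:\d+)*$/, "$2");
deAwsify("anthropic.claude-3-5-sonnet-20240620-v1:0"); // "claude-3-5-sonnet-20240620" -> "aws-claude"
deAwsify("mistral.mistral-large-2407-v1:0"); // "mistral-large-2407" -> "aws-mistral-large"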
export function getGcpModelFamily(model: string): GcpModelFamily {
@@ -223,8 +242,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
const model = req.body.model ?? "gpt-3.5-turbo";
let modelFamily: ModelFamily;
// Weird special case for AWS/GCP/Azure because they serve multiple models from
// different vendors, even if currently only one is supported.
// Weird special case for AWS/GCP/Azure because they serve models with
// different API formats, so the outbound API alone is not sufficient to
// determine the partition.
if (req.service === "aws") {
modelFamily = getAwsBedrockModelFamily(model);
} else if (req.service === "gcp") {
@@ -246,6 +266,7 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
modelFamily = getGoogleAIModelFamily(model);
break;
case "mistral-ai":
case "mistral-text":
modelFamily = getMistralAIModelFamily(model);
break;
default:
+3
View File
@@ -67,6 +67,9 @@ async function getTokenCountForMessages({
case "image":
numTokens += await getImageTokenCount(part.source.data);
break;
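// Tool use/result blocks currently contribute zero tokens to this estimate.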
case "tool_use":
case "tool_result":
break;
default:
throw new Error(`Unsupported Anthropic content type.`);
}
+3 -2
View File
@@ -47,9 +47,9 @@ type GoogleAIChatTokenCountRequest = {
};
type MistralAIChatTokenCountRequest = {
prompt: MistralAIChatMessage[];
prompt: string | MistralAIChatMessage[];
completion?: never;
service: "mistral-ai";
service: "mistral-ai" | "mistral-text";
};
type FlatPromptTokenCountRequest = {
@@ -128,6 +128,7 @@ export async function countTokens({
tokenization_duration_ms: getElapsedMs(time),
};
case "mistral-ai":
case "mistral-text":
return {
...getMistralAITokenCount(prompt ?? completion),
tokenization_duration_ms: getElapsedMs(time),
+1
View File
@@ -431,6 +431,7 @@ function getModelFamilyForQuotaUsage(
case "google-ai":
return getGoogleAIModelFamily(model);
case "mistral-ai":
case "mistral-text":
return getMistralAIModelFamily(model);
default:
assertNever(api);
+3 -3
View File
@@ -1,14 +1,14 @@
import cookieParser from "cookie-parser";
import expressSession from "express-session";
import MemoryStore from "memorystore";
import { config, COOKIE_SECRET } from "../config";
import { config, SECRET_SIGNING_KEY } from "../config";
const ONE_WEEK = 1000 * 60 * 60 * 24 * 7;
const cookieParserMiddleware = cookieParser(COOKIE_SECRET);
const cookieParserMiddleware = cookieParser(SECRET_SIGNING_KEY);
const sessionMiddleware = expressSession({
secret: COOKIE_SECRET,
secret: SECRET_SIGNING_KEY,
resave: false,
saveUninitialized: false,
store: new (MemoryStore(expressSession))({ checkPeriod: ONE_WEEK }),
+8 -19
View File
@@ -2,6 +2,7 @@ import crypto from "crypto";
import express from "express";
import argon2 from "@node-rs/argon2";
import { z } from "zod";
import { signMessage } from "../../shared/hmac-signing";
import {
authenticate,
createUser,
@@ -13,15 +14,13 @@ import { config } from "../../config";
/** Lockout time after verification in milliseconds */
const LOCKOUT_TIME = 1000 * 60; // 60 seconds
/** HMAC key for signing challenges; regenerated on startup */
let hmacSecret = crypto.randomBytes(32).toString("hex");
let powKeySalt = crypto.randomBytes(32).toString("hex");
/**
* Regenerate the HMAC key used for signing challenges. Calling this function
* will invalidate all existing challenges.
* Invalidates any outstanding unsolved challenges.
*/
export function invalidatePowHmacKey() {
hmacSecret = crypto.randomBytes(32).toString("hex");
export function invalidatePowChallenges() {
powKeySalt = crypto.randomBytes(32).toString("hex");
}
const argon2Params = {
@@ -141,16 +140,6 @@ function generateChallenge(clientIp?: string, token?: string): Challenge {
};
}
function signMessage(msg: any): string {
const hmac = crypto.createHmac("sha256", hmacSecret);
if (typeof msg === "object") {
hmac.update(JSON.stringify(msg));
} else {
hmac.update(msg);
}
return hmac.digest("hex");
}
async function verifySolution(
challenge: Challenge,
solution: string,
@@ -225,11 +214,11 @@ router.post("/challenge", (req, res) => {
return;
}
const challenge = generateChallenge(req.ip, refreshToken);
const signature = signMessage(challenge);
const signature = signMessage(challenge, powKeySalt);
res.json({ challenge, signature });
} else {
const challenge = generateChallenge(req.ip);
const signature = signMessage(challenge);
const signature = signMessage(challenge, powKeySalt);
res.json({ challenge, signature });
}
});
@@ -253,7 +242,7 @@ router.post("/verify", async (req, res) => {
}
const { challenge, signature, solution } = result.data;
if (signMessage(challenge) !== signature) {
if (signMessage(challenge, powKeySalt) !== signature) {
res.status(400).json({
error:
"Invalid signature; server may have restarted since challenge was issued. Please request a new challenge.",