24 Commits

Author SHA1 Message Date
nai-degen cfc1290f83 fixes aws keychecker not detecting claude 2.1 2024-08-14 10:46:54 -05:00
nai-degen 14f228f666 always applies Mistral prompt fixes on messages input 2024-08-14 10:44:37 -05:00
nai-degen d264fdd573 adds mistral chat-to-text transformation, for better prefix compatibility 2024-08-13 23:24:36 -05:00
nai-degen 9c3e345720 update deps 2024-08-13 20:31:19 -05:00
nai-degen 37c421bb45 fixes token counting for streaming Mistral Text prompts 2024-08-13 20:29:24 -05:00
nai-degen 6c5fed90e2 rename function 2024-08-13 20:15:14 -05:00
nai-degen 9479fa4ab0 serviceinfo tweak 2024-08-13 20:13:46 -05:00
nai-degen e145f5757e implements aws mistral streaming 2024-08-13 20:04:07 -05:00
nai-degen 2fe6e07cf5 error better 2024-08-12 20:49:21 -05:00
nai-degen bc340c1be6 non-streaming aws mistral works 2024-08-12 20:37:14 -05:00
nai-degen 45c5d3d338 fixes aws mistral keychecker model invocation 2024-08-12 19:32:26 -05:00
nai-degen 3032ae3198 express route matching is a pain in the ass 2024-08-12 19:31:53 -05:00
nai-degen 49a89122f5 fixes aws models endpoint 2024-08-12 19:26:55 -05:00
nai-degen 2d8e1dac13 adds /aws/mistral endpoint 2024-08-12 19:10:49 -05:00
nai-degen 9e5a660ef5 refactors aws endpoint router to split claude/mistral 2024-08-12 19:10:49 -05:00
nai-degen 6cf8c09fad removes 'server greeting' header from info page 2024-08-12 19:10:49 -05:00
nai-degen dc1b573020 small KeyProvider#get refactor 2024-08-12 19:10:49 -05:00
nai-degen 3ff771d945 fix gcp rebase issue 2024-08-12 19:10:49 -05:00
nai-degen 985035fe80 adds old test script to repo 2024-08-12 19:10:49 -05:00
nai-degen 442f9529de comments 2024-08-12 19:10:49 -05:00
nai-degen 598ac8e4e1 tries to unfuck service info stat aggregation slightly 2024-08-12 19:10:49 -05:00
nai-degen 750dbee483 adds support for non-Anthropic models to AWS key manager 2024-08-12 19:10:49 -05:00
nai-degen a2d64e281e minor KeyProvider#getLockoutPeriod refactor 2024-08-12 19:10:49 -05:00
nai-degen c6467b02f3 adds AWS mistral model families and checker IDs 2024-08-12 19:10:49 -05:00
36 changed files with 296 additions and 513 deletions
+6 -84
View File
@@ -17,6 +17,7 @@
"@smithy/eventstream-serde-node": "^2.1.3",
"@smithy/protocol-http": "^3.2.1",
"@smithy/signature-v4": "^2.1.3",
"@smithy/types": "^2.10.1",
"@smithy/util-utf8": "^2.1.1",
"axios": "^1.7.4",
"better-sqlite3": "^10.0.0",
@@ -51,7 +52,6 @@
"zod-error": "^1.5.0"
},
"devDependencies": {
"@smithy/types": "^3.3.0",
"@types/better-sqlite3": "^7.6.10",
"@types/cookie-parser": "^1.4.3",
"@types/cors": "^2.8.13",
@@ -152,17 +152,6 @@
"node": ">=14.0.0"
}
},
"node_modules/@aws-sdk/types/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@aws-sdk/util-utf8-browser": {
"version": "3.259.0",
"resolved": "https://registry.npmjs.org/@aws-sdk/util-utf8-browser/-/util-utf8-browser-3.259.0.tgz",
@@ -1339,17 +1328,6 @@
"tslib": "^2.5.0"
}
},
"node_modules/@smithy/eventstream-codec/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-node": {
"version": "2.1.3",
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-node/-/eventstream-serde-node-2.1.3.tgz",
@@ -1363,17 +1341,6 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-node/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-universal": {
"version": "2.1.3",
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-universal/-/eventstream-serde-universal-2.1.3.tgz",
@@ -1387,17 +1354,6 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-universal/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/is-array-buffer": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-2.1.1.tgz",
@@ -1421,17 +1377,6 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/protocol-http/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/signature-v4": {
"version": "2.1.3",
"resolved": "https://registry.npmjs.org/@smithy/signature-v4/-/signature-v4-2.1.3.tgz",
@@ -1450,29 +1395,17 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/signature-v4/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"node_modules/@smithy/types": {
"version": "2.10.1",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.10.1.tgz",
"integrity": "sha512-hjQO+4ru4cQ58FluQvKKiyMsFg0A6iRpGm2kqdH8fniyNd2WyanoOsYJfMX/IFLuLxEoW6gnRkNZy1y6fUUhtA==",
"dependencies": {
"tslib": "^2.6.2"
"tslib": "^2.5.0"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/types": {
"version": "3.3.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-3.3.0.tgz",
"integrity": "sha512-IxvBBCTFDHbVoK7zIxqA1ZOdc4QfM5HM7rGleCuHi7L1wnKv5Pn69xXJQ9hgxH60ZVygH9/JG0jRgtUncE3QUA==",
"dev": true,
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=16.0.0"
}
},
"node_modules/@smithy/util-buffer-from": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-2.1.1.tgz",
@@ -1508,17 +1441,6 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/util-middleware/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/util-uri-escape": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/@smithy/util-uri-escape/-/util-uri-escape-2.1.1.tgz",
+1 -1
View File
@@ -26,6 +26,7 @@
"@smithy/eventstream-serde-node": "^2.1.3",
"@smithy/protocol-http": "^3.2.1",
"@smithy/signature-v4": "^2.1.3",
"@smithy/types": "^2.10.1",
"@smithy/util-utf8": "^2.1.1",
"axios": "^1.7.4",
"better-sqlite3": "^10.0.0",
@@ -60,7 +61,6 @@
"zod-error": "^1.5.0"
},
"devDependencies": {
"@smithy/types": "^3.3.0",
"@types/better-sqlite3": "^7.6.10",
"@types/cookie-parser": "^1.4.3",
"@types/cors": "^2.8.13",
+2 -3
View File
@@ -17,7 +17,7 @@ import {
} from "../../shared/users/schema";
import { getLastNImages } from "../../shared/file-storage/image-history";
import { blacklists, parseCidrs, whitelists } from "../../shared/cidr";
import { invalidatePowChallenges } from "../../user/web/pow-captcha";
import { invalidatePowHmacKey } from "../../user/web/pow-captcha";
const router = Router();
@@ -323,7 +323,7 @@ router.post("/maintenance", (req, res) => {
user.disabledReason = "Admin forced expiration.";
userStore.upsertUser(user);
});
invalidatePowChallenges();
invalidatePowHmacKey();
flash.type = "success";
flash.message = `${temps.length} temporary users marked for expiration.`;
break;
@@ -348,7 +348,6 @@ router.post("/maintenance", (req, res) => {
throw new HttpError(400, "Invalid difficulty" + selected);
}
config.powDifficultyLevel = selected;
invalidatePowChallenges();
break;
}
case "generateTempIpReport": {
+5 -5
View File
@@ -415,18 +415,18 @@ export const config: Config = {
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
textModelRateLimit: getEnvWithDefault("TEXT_MODEL_RATE_LIMIT", 4),
imageModelRateLimit: getEnvWithDefault("IMAGE_MODEL_RATE_LIMIT", 4),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 32768),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 16384),
maxContextTokensAnthropic: getEnvWithDefault(
"MAX_CONTEXT_TOKENS_ANTHROPIC",
32768
0
),
maxOutputTokensOpenAI: getEnvWithDefault(
["MAX_OUTPUT_TOKENS_OPENAI", "MAX_OUTPUT_TOKENS"],
1024
400
),
maxOutputTokensAnthropic: getEnvWithDefault(
["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"],
1024
400
),
allowedModelFamilies: getEnvWithDefault(
"ALLOWED_MODEL_FAMILIES",
@@ -519,7 +519,7 @@ function generateSigningKey() {
}
const signingKey = generateSigningKey();
export const SECRET_SIGNING_KEY = signingKey;
export const COOKIE_SECRET = signingKey;
export async function assertConfigIsValid() {
if (process.env.MODEL_RATE_LIMIT !== undefined) {
+68 -31
View File
@@ -46,7 +46,7 @@ const getModelsResponse = () => {
"claude-3-haiku-20240307",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20240620"
];
const models = claudeVariants.map((id) => ({
@@ -70,7 +70,7 @@ const handleModelRequest: RequestHandler = (_req, res) => {
};
/** Only used for non-streaming requests. */
const anthropicBlockingResponseHandler: ProxyResHandlerWithBody = async (
const anthropicResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
@@ -179,28 +179,6 @@ export function transformAnthropicChatResponseToOpenAI(
};
}
/**
 * If a client using the OpenAI compatibility endpoint requests an actual OpenAI
 * model, reassigns it to Claude 3 Sonnet.
 *
 * Requests whose `model` is missing or is not a string are left untouched so
 * that this hook never throws; any other model name is passed through as-is.
 *
 * @param req Incoming request; `req.body.model` is rewritten in place when it
 * starts with `gpt-`.
 */
function maybeReassignModel(req: Request) {
  const model: unknown = req.body?.model;
  // Guard before calling a string method: clients may omit `model` entirely.
  if (typeof model !== "string" || !model.startsWith("gpt-")) return;
  req.body.model = "claude-3-sonnet-20240229";
}
/**
 * If client requests more than 4096 output tokens the request must have a
 * particular version header.
 * https://docs.anthropic.com/en/release-notes/api#july-15th-2024
 *
 * The output token limit is read from `max_tokens` (the field used by
 * anthropic-chat request bodies) and falls back to the legacy
 * `max_tokens_to_sample` field so that either request shape triggers the
 * header.
 *
 * @param req Incoming request; `req.headers["anthropic-beta"]` is set in place
 * when the requested output size exceeds 4096 tokens.
 */
function setAnthropicBetaHeader(req: Request) {
  const { max_tokens, max_tokens_to_sample } = req.body;
  const requestedOutputTokens = max_tokens ?? max_tokens_to_sample;
  // A missing limit compares as `undefined > 4096`, which is false.
  if (requestedOutputTokens > 4096) {
    req.headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15";
  }
}
const anthropicProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://api.anthropic.com",
@@ -211,7 +189,7 @@ const anthropicProxy = createQueueMiddleware({
proxyReq: createOnProxyReqHandler({
pipeline: [addKey, addAnthropicPreamble, finalizeBody],
}),
proxyRes: createOnProxyResHandler([anthropicBlockingResponseHandler]),
proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
error: handleProxyError,
},
// Abusing pathFilter to rewrite the paths dynamically.
@@ -235,11 +213,6 @@ const anthropicProxy = createQueueMiddleware({
}),
});
const nativeAnthropicChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "anthropic" },
{ afterTransform: [setAnthropicBetaHeader] }
);
const nativeTextPreprocessor = createPreprocessorMiddleware({
inApi: "anthropic-text",
outApi: "anthropic-text",
@@ -295,7 +268,11 @@ anthropicRouter.get("/v1/models", handleModelRequest);
anthropicRouter.post(
"/v1/messages",
ipLimiter,
nativeAnthropicChatPreprocessor,
createPreprocessorMiddleware({
inApi: "anthropic-chat",
outApi: "anthropic-chat",
service: "anthropic",
}),
anthropicProxy
);
// Anthropic text completion endpoint. Translates to Anthropic chat completion
@@ -315,5 +292,65 @@ anthropicRouter.post(
preprocessOpenAICompatRequest,
anthropicProxy
);
// Temporarily force Anthropic Text to Anthropic Chat for frontends which do not
// yet support the new model. Forces claude-3. Will be removed once common
// frontends have been updated.
anthropicRouter.post(
"/v1/:type(sonnet|opus)/:action(complete|messages)",
ipLimiter,
handleAnthropicTextCompatRequest,
createPreprocessorMiddleware({
inApi: "anthropic-text",
outApi: "anthropic-chat",
service: "anthropic",
}),
anthropicProxy
);
/**
 * Middleware for the temporary `/v1/:type/:action` compatibility route.
 *
 * Logs the request, then either forces the request model to the Claude 3
 * variant named by the `:type` route parameter and hands off to the next
 * middleware, or rejects the request with a 400 when the client is already
 * speaking the chat format (the `:action` is `messages` or the body already
 * contains `messages`).
 *
 * @param req Incoming request; `req.body.model` is overwritten on success.
 * @param res Response used to deliver the error when the request is rejected.
 * @param next Continuation invoked only when the request is accepted.
 */
function handleAnthropicTextCompatRequest(
  req: Request,
  res: Response,
  next: any
) {
  const { type, action } = req.params;
  const alreadyInChatFormat = Boolean(req.body.messages);
  const compatModel = `claude-3-${type}-20240229`;

  req.log.info(
    { type, inputModel: req.body.model, compatModel, alreadyInChatFormat },
    "Handling Anthropic compatibility request"
  );

  // Happy path first: a text-format request gets the forced model.
  const needsCompat = action !== "messages" && !alreadyInChatFormat;
  if (needsCompat) {
    req.body.model = compatModel;
    next();
    return;
  }

  // The client already uses the chat format and should not be on this route.
  const options = {
    title: "Unnecessary usage of compatibility endpoint",
    message: `Your client seems to already support the new Claude API format. This endpoint is intended for clients that do not yet support the new format.\nUse the normal \`/anthropic\` proxy endpoint instead.`,
    format: "unknown",
    statusCode: 400,
    reqId: req.id,
    obj: {
      requested_endpoint: "/anthropic/" + type,
      correct_endpoint: "/anthropic",
    },
  };
  return sendErrorToClient({ req, res, options });
}
/**
 * If a client using the OpenAI compatibility endpoint requests an actual OpenAI
 * model, reassigns it to Claude 3 Sonnet.
 *
 * Requests whose `model` is missing or is not a string are left untouched so
 * that this hook never throws; any other model name is passed through as-is.
 *
 * @param req Incoming request; `req.body.model` is rewritten in place when it
 * starts with `gpt-`.
 */
function maybeReassignModel(req: Request) {
  const model: unknown = req.body?.model;
  // Guard before calling a string method: clients may omit `model` entirely.
  if (typeof model !== "string" || !model.startsWith("gpt-")) return;
  req.body.model = "claude-3-sonnet-20240229";
}
export const anthropic = anthropicRouter;
+4 -1
View File
@@ -1,5 +1,6 @@
import { Request, RequestHandler, Router } from "express";
import { Request, RequestHandler, Response, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
@@ -16,6 +17,8 @@ import {
createOnProxyResHandler,
} from "./middleware/response";
import { transformAnthropicChatResponseToOpenAI } from "./anthropic";
import { sendErrorToClient } from "./middleware/response/error-generator";
const LATEST_GCP_SONNET_MINOR_VERSION = "20240229";
let modelsCache: any = null;
+1 -1
View File
@@ -152,7 +152,7 @@ googleAIRouter.post(
outApi: "google-ai",
service: "google-ai",
},
{ beforeTransform: [maybeReassignModel], afterTransform: [setStreamFlag] }
{ afterTransform: [maybeReassignModel, setStreamFlag] }
),
googleAIProxy
);
-2
View File
@@ -16,7 +16,6 @@ const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
const ANTHROPIC_MESSAGES_ENDPOINT = "/v1/messages";
const ANTHROPIC_SONNET_COMPAT_ENDPOINT = "/v1/sonnet";
const ANTHROPIC_OPUS_COMPAT_ENDPOINT = "/v1/opus";
const GOOGLE_AI_COMPLETION_ENDPOINT = "/v1beta/models";
export function isTextGenerationRequest(req: Request) {
return (
@@ -28,7 +27,6 @@ export function isTextGenerationRequest(req: Request) {
ANTHROPIC_MESSAGES_ENDPOINT,
ANTHROPIC_SONNET_COMPAT_ENDPOINT,
ANTHROPIC_OPUS_COMPAT_ENDPOINT,
GOOGLE_AI_COMPLETION_ENDPOINT,
].some((endpoint) => req.path.startsWith(endpoint))
);
}
@@ -1,16 +1,14 @@
import { HPMRequestCallback } from "../index";
import { config } from "../../../../config";
import { ForbiddenError } from "../../../../shared/errors";
import { getModelFamilyForRequest } from "../../../../shared/models";
import { HPMRequestCallback } from "../index";
/**
* Ensures the selected model family is enabled by the proxy configuration.
*/
export const checkModelFamily: HPMRequestCallback = (_proxyReq, req) => {
**/
export const checkModelFamily: HPMRequestCallback = (_proxyReq, req, res) => {
const family = getModelFamilyForRequest(req);
if (!config.allowedModelFamilies.includes(family)) {
throw new ForbiddenError(
`Model family '${family}' is not enabled on this proxy`
);
throw new ForbiddenError(`Model family '${family}' is not enabled on this proxy`);
}
};
@@ -84,7 +84,7 @@ async function executePreprocessors(
} catch (error) {
if (error.constructor.name === "ZodError") {
const msg = error?.issues
?.map((issue: ZodIssue) => `${issue.path.join(".")}: ${issue.message}`)
?.map((issue: ZodIssue) => issue.message)
.join("; ");
req.log.warn({ issues: msg }, "Prompt validation failed.");
} else {
@@ -30,13 +30,10 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
}
case "anthropic-chat": {
req.outputTokens = req.body.max_tokens;
let system = req.body.system ?? "";
if (Array.isArray(system)) {
system = system
.map((m: { type: string; text: string }) => m.text)
.join("\n");
}
const prompt = { system, messages: req.body.messages };
const prompt = {
system: req.body.system ?? "",
messages: req.body.messages,
};
result = await countTokens({ req, prompt, service });
break;
}
@@ -6,7 +6,7 @@ import {
AnthropicV1TextSchema,
AnthropicV1MessagesSchema,
} from "../../../../shared/api-schemas";
import { AwsBedrockKey, keyPool } from "../../../../shared/key-management";
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import {
AWSMistralV1ChatCompletionsSchema,
@@ -35,26 +35,17 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
const credential = getCredentialParts(req);
const host = AMZ_HOST.replace("%REGION%", credential.region);
// AWS only uses 2023-06-01 and does not actually check this header, but we
// set it so that the stream adapter always selects the correct transformer.
req.headers["anthropic-version"] = "2023-06-01";
// If our key has an inference profile compatible with the requested model,
// we want to use the inference profile instead of the model ID when calling
// InvokeModel as that will give us higher rate limits.
const profile =
(req.key as AwsBedrockKey).inferenceProfileIds.find((p) =>
p.includes(model)
) || model;
// Uses the AWS SDK to sign a request, then modifies our HPM proxy request
// with the headers generated by the SDK.
const newRequest = new HttpRequest({
method: "POST",
protocol: "https:",
hostname: host,
path: `/model/${profile}/invoke${stream ? "-with-response-stream" : ""}`,
path: `/model/${model}/invoke${stream ? "-with-response-stream" : ""}`,
headers: {
["Host"]: host,
["content-type"]: "application/json",
@@ -70,13 +61,7 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
const { key, body, inboundApi, outboundApi } = req;
req.log.info(
{
key: key.hash,
model: body.model,
inferenceProfile: profile,
inboundApi,
outboundApi,
},
{ key: key.hash, model: body.model, inboundApi, outboundApi },
"Assigned AWS credentials to request"
);
@@ -144,8 +129,6 @@ function applyAwsStrictValidation(req: Request): unknown {
temperature: true,
top_k: true,
top_p: true,
tools: true,
tool_choice: true,
})
.strip()
.parse(req.body);
@@ -24,6 +24,7 @@ export const signGcpRequest: RequestPreprocessor = async (req) => {
req.isStreaming = String(stream) === "true";
// TODO: This should happen in transform-outbound-payload.ts
// TODO: Support tools
let strippedParams: Record<string, unknown>;
strippedParams = AnthropicV1MessagesSchema.pick({
messages: true,
@@ -33,8 +34,6 @@ export const signGcpRequest: RequestPreprocessor = async (req) => {
temperature: true,
top_k: true,
top_p: true,
tools: true,
tool_choice: true,
stream: true,
})
.strip()
@@ -17,17 +17,7 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
const notTransformable =
!isTextGenerationRequest(req) && !isImageGenerationRequest(req);
if (alreadyTransformed) {
return;
} else if (notTransformable) {
// This is probably an indication of a bug in the proxy.
const { inboundApi, outboundApi, method, path } = req;
req.log.warn(
{ inboundApi, outboundApi, method, path },
"`transformOutboundPayload` called on a non-transformable request."
);
return;
}
if (alreadyTransformed || notTransformable) return;
applyMistralPromptFixes(req);
@@ -77,13 +67,11 @@ function applyMistralPromptFixes(req: Request): void {
);
// If the prompt relies on `prefix: true` for the last message, we need to
// convert it to a text completions request because AWS Mistral support for
// this feature is broken.
// On Mistral La Plateforme, we can't do this because they don't expose
// a text completions endpoint.
// convert it to a text completions request because Mistral support for
// this feature is limited (and completely broken on AWS Mistral).
const { messages } = req.body;
const lastMessage = messages && messages[messages.length - 1];
if (lastMessage?.role === "assistant" && req.service === "aws") {
if (lastMessage && lastMessage.role === "assistant") {
// enable prefix if client forgot, otherwise the template will insert an
// eos token which is very unlikely to be what the client wants.
lastMessage.prefix = true;
@@ -58,8 +58,6 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
modelMax = 16384;
} else if (model.match(/^gpt-4o/)) {
modelMax = 128000;
} else if (model.match(/^chatgpt-4o/)) {
modelMax = 128000;
} else if (model.match(/gpt-4-turbo(-\d{4}-\d{2}-\d{2})?$/)) {
modelMax = 131072;
} else if (model.match(/gpt-4-turbo(-preview)?$/)) {
+1 -7
View File
@@ -75,13 +75,7 @@ const getPromptForRequest = (
case "mistral-ai":
return req.body.messages;
case "anthropic-chat":
let system = req.body.system;
if (Array.isArray(system)) {
system = system
.map((m: { type: string; text: string }) => m.text)
.join("\n");
}
return { system, messages: req.body.messages };
return { system: req.body.system, messages: req.body.messages };
case "openai-text":
case "anthropic-text":
case "mistral-text":
+1 -7
View File
@@ -21,7 +21,6 @@ import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { BadRequestError } from "../shared/errors";
// Mistral can't settle on a single naming scheme and deprecates models within
// months of releasing them so this list is hard to keep up to date. 2024-07-28
@@ -171,12 +170,7 @@ export function detectMistralInputApi(req: Request) {
if (messages) {
req.inboundApi = "mistral-ai";
req.outboundApi = "mistral-ai";
} else if (prompt && req.service === "mistral-ai") {
// Mistral La Plateforme doesn't expose a text completions endpoint.
throw new BadRequestError(
"Mistral (via La Plateforme API) does not support text completions. This format is only supported on Mistral via the AWS API."
);
} else if (prompt && req.service === "aws") {
} else if (prompt) {
req.inboundApi = "mistral-text";
req.outboundApi = "mistral-text";
}
-2
View File
@@ -35,8 +35,6 @@ export const KNOWN_OPENAI_MODELS = [
// GPT4o Mini
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
// GPT4o (ChatGPT)
"chatgpt-4o-latest",
// GPT4 Turbo (superseded by GPT4o)
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09", // gpt4-turbo stable, with vision
+48 -12
View File
@@ -22,7 +22,7 @@ import {
} from "../shared/models";
import { initializeSseStream } from "../shared/streaming";
import { logger } from "../logger";
import { getUniqueIps } from "./rate-limit";
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
import { handleProxyError } from "./middleware/common";
import { sendErrorToClient } from "./middleware/response/error-generator";
@@ -31,9 +31,7 @@ const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
/** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = parseInt(
process.env.USER_CONCURRENCY_LIMIT ?? "1"
);
const USER_CONCURRENCY_LIMIT = parseInt(process.env.USER_CONCURRENCY_LIMIT ?? "1");
/** Maximum number of queue slots for Agnai.chat requests. */
const AGNAI_CONCURRENCY_LIMIT = USER_CONCURRENCY_LIMIT * 5;
const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
@@ -60,20 +58,39 @@ const QUEUE_JOIN_TIMEOUT = 5000;
function getIdentifier(req: Request) {
if (req.user) return req.user.token;
if (req.risuToken) return req.risuToken;
// if (isFromSharedIp(req)) return "shared-ip";
if (isFromSharedIp(req)) return "shared-ip";
return req.ip;
}
const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
getIdentifier(queued) === getIdentifier(incoming);
const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);
async function enqueue(req: Request) {
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
let isGuest = req.user?.token === undefined;
if (enqueuedRequestCount >= USER_CONCURRENCY_LIMIT) {
throw new TooManyRequestsError(
"Your IP or user token already has another request in the queue."
);
// Requests from shared IP addresses such as Agnai.chat are exempt from IP-
// based rate limiting but can only occupy a certain number of slots in the
// queue. Authenticated users always get a single spot in the queue.
const isSharedIp = isFromSharedIp(req);
const maxConcurrentQueuedRequests =
isGuest && isSharedIp ? AGNAI_CONCURRENCY_LIMIT : USER_CONCURRENCY_LIMIT;
if (enqueuedRequestCount >= maxConcurrentQueuedRequests) {
if (isSharedIp) {
// Re-enqueued requests are not counted towards the limit since they
// already made it through the queue once.
if (req.retryCount === 0) {
throw new TooManyRequestsError(
"Too many agnai.chat requests are already queued"
);
}
} else {
throw new TooManyRequestsError(
"Your IP or user token already has another request in the queue."
);
}
}
// shitty hack to remove hpm's event listeners on retried requests
@@ -129,7 +146,19 @@ export async function reenqueueRequest(req: Request) {
}
function getQueueForPartition(partition: ModelFamily): Request[] {
return queue.filter((req) => getModelFamilyForRequest(req) === partition);
return queue
.filter((req) => getModelFamilyForRequest(req) === partition)
.sort((a, b) => {
// Certain requests are exempted from IP-based rate limiting because they
// come from a shared IP address. To prevent these requests from starving
// out other requests during periods of high traffic, we sort them to the
// end of the queue.
const aIsExempted = isFromSharedIp(a);
const bIsExempted = isFromSharedIp(b);
if (aIsExempted && !bIsExempted) return 1;
if (!aIsExempted && bIsExempted) return -1;
return 0;
});
}
export function dequeue(partition: ModelFamily): Request | undefined {
@@ -232,6 +261,7 @@ let waitTimes: {
partition: ModelFamily;
start: number;
end: number;
isDeprioritized: boolean;
}[] = [];
/** Adds a successful request to the list of wait times. */
@@ -240,6 +270,7 @@ export function trackWaitTime(req: Request) {
partition: getModelFamilyForRequest(req),
start: req.startTime!,
end: req.queueOutTime ?? Date.now(),
isDeprioritized: isFromSharedIp(req),
});
}
@@ -265,7 +296,8 @@ function calculateWaitTime(partition: ModelFamily) {
.filter((wait) => {
const isSamePartition = wait.partition === partition;
const isRecent = now - wait.end < 300 * 1000;
return isSamePartition && isRecent;
const isNormalPriority = !wait.isDeprioritized;
return isSamePartition && isRecent && isNormalPriority;
})
.map((wait) => wait.end - wait.start);
const recentAverage = recentWaits.length
@@ -279,7 +311,11 @@ function calculateWaitTime(partition: ModelFamily) {
);
const currentWaits = queue
.filter((req) => getModelFamilyForRequest(req) === partition)
.filter((req) => {
const isSamePartition = getModelFamilyForRequest(req) === partition;
const isNormalPriority = !isFromSharedIp(req);
return isSamePartition && isNormalPriority;
})
.map((req) => now - req.startTime!);
const longestCurrentWait = Math.max(...currentWaits, 0);
+32 -15
View File
@@ -1,6 +1,14 @@
import { Request, Response, NextFunction } from "express";
import { config } from "../config";
export const SHARED_IP_ADDRESSES = new Set([
// Agnai.chat
"157.230.249.32", // old
"157.245.148.56",
"174.138.29.50",
"209.97.162.44",
]);
const ONE_MINUTE_MS = 60 * 1000;
type Timestamp = number;
@@ -12,10 +20,7 @@ const exemptedRequests: Timestamp[] = [];
const isRecentAttempt = (now: Timestamp) => (attempt: Timestamp) =>
attempt > now - ONE_MINUTE_MS;
/**
* Returns duration in seconds to wait before retrying for Retry-After header.
*/
const getRetryAfter = (ip: string, type: "text" | "image") => {
const getTryAgainInMs = (ip: string, type: "text" | "image") => {
const now = Date.now();
const attempts = lastAttempts.get(ip) || [];
const validAttempts = attempts.filter(isRecentAttempt(now));
@@ -24,7 +29,7 @@ const getRetryAfter = (ip: string, type: "text" | "image") => {
type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;
if (validAttempts.length >= limit) {
return (validAttempts[0] - now + ONE_MINUTE_MS) / 1000;
return validAttempts[0] - now + ONE_MINUTE_MS;
} else {
lastAttempts.set(ip, [...validAttempts, now]);
return 0;
@@ -91,11 +96,22 @@ export const ipLimiter = async (
if (!textLimit && !imageLimit) return next();
if (req.user?.type === "special") return next();
const path = req.baseUrl + req.path;
const type =
path.includes("openai-image") || path.includes("images/generations")
? "image"
: "text";
// Exempts Agnai.chat from IP-based rate limiting because its IPs are shared
// by many users. Instead, the request queue will limit the number of such
// requests that may wait in the queue at a time, and sorts them to the end to
// let individual users go first.
if (SHARED_IP_ADDRESSES.has(req.ip)) {
exemptedRequests.push(Date.now());
req.log.info(
{ ip: req.ip, recentExemptions: exemptedRequests.length },
"Exempting Agnai request from rate limiting."
);
return next();
}
const type = (req.baseUrl + req.path).includes("openai-image")
? "image"
: "text";
const limit = type === "image" ? imageLimit : textLimit;
// If user is authenticated, key rate limiting by their token. Otherwise, key
@@ -107,14 +123,15 @@ export const ipLimiter = async (
res.set("X-RateLimit-Remaining", remaining.toString());
res.set("X-RateLimit-Reset", reset.toString());
const retryAfterTime = getRetryAfter(rateLimitKey, type);
if (retryAfterTime > 0) {
const waitSec = Math.ceil(retryAfterTime).toString();
res.set("Retry-After", waitSec);
const tryAgainInMs = getTryAgainInMs(rateLimitKey, type);
if (tryAgainInMs > 0) {
res.set("Retry-After", tryAgainInMs.toString());
res.status(429).json({
error: {
type: "proxy_rate_limited",
message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${waitSec} seconds.`,
message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${Math.ceil(
tryAgainInMs / 1000
)} seconds.`,
},
});
} else {
-1
View File
@@ -49,7 +49,6 @@ app.use(
// Don't log the prompt text on transform errors
"body.messages",
"body.prompt",
"body.contents",
],
censor: "********",
},
+1 -23
View File
@@ -19,12 +19,7 @@ const AnthropicV1BaseSchema = z
top_k: z.coerce.number().optional(),
top_p: z.coerce.number().optional(),
metadata: z.object({ user_id: z.string().optional() }).optional(),
tools: z.array(z.any()).optional(),
tool_choice: z.any().optional(),
})
.omit(
Boolean(config.allowOpenAIToolUsage) ? {} : { tools: true, tool_choice: true }
)
.strip();
// https://docs.anthropic.com/claude/reference/complete_post [deprecated]
@@ -49,18 +44,6 @@ const AnthropicV1MessageMultimodalContentSchema = z.array(
data: z.string(),
}),
}),
z.object({
type: z.literal("tool_use"),
id: z.string(),
name: z.string(),
input: z.object({}).passthrough(),
}),
z.object({
type: z.literal("tool_result"),
tool_use_id: z.string(),
is_error: z.boolean().optional(),
content: z.union([z.string(), z.object({}).passthrough()]).optional(),
}),
])
);
@@ -80,12 +63,7 @@ export const AnthropicV1MessagesSchema = AnthropicV1BaseSchema.merge(
.number()
.int()
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
system: z
.union([
z.string(),
z.array(z.object({ type: z.literal("text"), text: z.string() })),
])
.optional(),
system: z.string().optional(),
})
);
export type AnthropicChatMessage = z.infer<
+1 -1
View File
@@ -31,7 +31,7 @@ export const GoogleAIV1GenerateContentSchema = z
topP: z.number().optional(),
topK: z.number().optional(),
stopSequences: z.array(z.string().max(500)).max(5).optional(),
}).default({}),
}),
})
.strip();
export type GoogleAIChatMessage = z.infer<
+1 -3
View File
@@ -45,9 +45,7 @@ const BaseMistralAIV1CompletionsSchema = z.object({
.default([])
.transform((v) => (Array.isArray(v) ? v : [v])),
random_seed: z.number().int().min(0).optional(),
response_format: z
.object({ type: z.enum(["text", "json_object"]) })
.optional(),
response_format: z.enum(["text", "json_object"]).optional().default("text"),
safe_prompt: z.boolean().optional().default(false),
});
+1 -1
View File
@@ -52,7 +52,7 @@ export const OpenAIV1ChatCompletionSchema = z
.number()
.int()
.nullish()
.default(Math.min(OPENAI_OUTPUT_MAX, 16384))
.default(Math.min(OPENAI_OUTPUT_MAX, 4096))
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
frequency_penalty: z.number().optional().default(0),
presence_penalty: z.number().optional().default(0),
-18
View File
@@ -1,18 +0,0 @@
/** Module for generating and verifying HMAC signatures. */
import crypto from "crypto";
import { SECRET_SIGNING_KEY } from "../config";
/**
 * Generates a hex-encoded HMAC-SHA256 signature for the given message.
 *
 * The HMAC key is `SECRET_SIGNING_KEY` concatenated with `salt`, so passing a
 * different salt yields a different signature for the same message.
 *
 * Serialization rules:
 * - objects (including arrays and `null`) are signed as their
 *   `JSON.stringify` output, so the signature depends on property order;
 * - strings are signed as-is;
 * - any other primitive (number, boolean, bigint, undefined, ...) is signed as
 *   its `String()` form. Previously these were passed straight to
 *   `hmac.update()`, which only accepts strings and binary data and throws.
 *
 * @param msg The value to sign.
 * @param salt Optional string appended to the signing key.
 * @returns The signature as a lowercase hexadecimal string.
 */
export function signMessage(msg: any, salt: string = ""): string {
  const hmac = crypto.createHmac("sha256", SECRET_SIGNING_KEY + salt);
  if (typeof msg === "object") {
    hmac.update(JSON.stringify(msg));
  } else if (typeof msg === "string") {
    hmac.update(msg);
  } else {
    // hmac.update() rejects non-string primitives, so stringify them first.
    hmac.update(String(msg));
  }
  return hmac.digest("hex");
}
+2 -2
View File
@@ -1,9 +1,9 @@
import { doubleCsrf } from "csrf-csrf";
import express from "express";
import { config, SECRET_SIGNING_KEY } from "../config";
import { config, COOKIE_SECRET } from "../config";
const { generateToken, doubleCsrfProtection } = doubleCsrf({
getSecret: () => SECRET_SIGNING_KEY,
getSecret: () => COOKIE_SECRET,
cookieName: "csrf",
cookieOptions: {
sameSite: "strict",
+39 -129
View File
@@ -1,12 +1,12 @@
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
import axios, { AxiosError, AxiosHeaders, AxiosRequestConfig } from "axios";
import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios";
import { URL } from "url";
import { config } from "../../../config";
import { getAwsBedrockModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
import { getAwsBedrockModelFamily } from "../../models";
import { config } from "../../../config";
type ParentModelId = string;
type AliasModelId = string;
@@ -24,8 +24,6 @@ const KNOWN_MODEL_IDS: ModuleAliasTuple[] = [
["mistral.mistral-large-2407-v1:0"],
["mistral.mistral-small-2402-v1:0"], // Seems to return 400
];
const KEY_CHECK_BATCH_SIZE = 2; // AWS checker needs to do lots of concurrent requests so should lower the batch size
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
const AMZ_HOST =
@@ -33,8 +31,6 @@ const AMZ_HOST =
const GET_CALLER_IDENTITY_URL = `https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15`;
const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) =>
`https://bedrock.${region}.amazonaws.com/logging/modelinvocations`;
const GET_LIST_INFERENCE_PROFILES_URL = (region: string) =>
`https://bedrock.${region}.amazonaws.com/inference-profiles?maxResults=1000`;
const POST_INVOKE_MODEL_URL = (region: string, model: string) =>
`https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`;
const TEST_MESSAGES = [
@@ -44,22 +40,6 @@ const TEST_MESSAGES = [
type AwsError = { error: {} };
type GetInferenceProfilesResponse = {
inferenceProfileSummaries: {
inferenceProfileId: string;
inferenceProfileName: string;
inferenceProfileArn: string;
description?: string;
createdAt?: string;
updatedAt?: string;
status: "ACTIVE" | unknown;
type: "SYSTEM_DEFINED" | unknown;
models: {
modelArn?: string;
}[];
}[];
};
type GetLoggingConfigResponse = {
loggingConfig: null | {
cloudWatchConfig: null | unknown;
@@ -78,7 +58,6 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
service: "aws",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
keyCheckBatchSize: KEY_CHECK_BATCH_SIZE,
updateKey,
});
}
@@ -87,51 +66,37 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
try {
await this.checkInferenceProfiles(key);
} catch (e) {
const asError = e as AxiosError<AwsError>;
const data = asError.response?.data;
this.log.warn(
{ key: key.hash, error: e.message, data },
"Cannot list inference profiles.\n\
Principal may be missing `AmazonBedrockFullAccess`, or has no policy allowing action `bedrock:ListInferenceProfiles` against resource `arn:aws:bedrock:*:*:inference-profile/*`.\n\
Requests will be made without inference profiles using on-demand quotas, which may be subject to more restrictive rate limits.\n\
See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-prereq.html."
);
}
}
// Perform checks for all parent model IDs
const results = await Promise.all(
KNOWN_MODEL_IDS.filter(([model]) =>
// Skip checks for models that are disabled anyway
config.allowedModelFamilies.includes(getAwsBedrockModelFamily(model))
).map(async ([model, ...aliases]) => ({
models: [model, ...aliases],
success: await this.invokeModel(model, key),
}))
);
// Filter out models that are disabled
const modelIds = results
.filter(({ success }) => success)
.flatMap(({ models }) => models);
if (modelIds.length === 0) {
this.log.warn(
{ key: key.hash },
"Key does not have access to any models; disabling."
// Perform checks for all parent model IDs
const results = await Promise.all(
KNOWN_MODEL_IDS.filter(([model]) =>
// Skip checks for models that are disabled anyway
config.allowedModelFamilies.includes(getAwsBedrockModelFamily(model))
).map(async ([model, ...aliases]) => ({
models: [model, ...aliases],
success: await this.invokeModel(model, key),
}))
);
return this.updateKey(key.hash, { isDisabled: true });
}
this.updateKey(key.hash, {
modelIds,
modelFamilies: Array.from(
new Set(modelIds.map(getAwsBedrockModelFamily))
),
});
// Filter out models that are disabled
const modelIds = results
.filter(({ success }) => success)
.flatMap(({ models }) => models);
if (modelIds.length === 0) {
this.log.warn(
{ key: key.hash },
"Key does not have access to any models; disabling."
);
return this.updateKey(key.hash, { isDisabled: true });
}
this.updateKey(key.hash, {
modelIds,
modelFamilies: Array.from(
new Set(modelIds.map(getAwsBedrockModelFamily))
),
});
}
this.log.info(
{
@@ -214,36 +179,6 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
key: AwsBedrockKey
): Promise<boolean> {
if (model.includes("claude")) {
// If inference profiles are available, try testing model with them.
// If they are not available or the invocation fails with the inference
// profile, fall back to regular model ID.
const { region } = AwsKeyChecker.getCredentialsFromKey(key);
const continent = region.split("-")[0];
const profile = key.inferenceProfileIds.find(
(id) => `${continent}.${model}` === id
);
if (profile) {
this.log.debug(
{ key: key.hash, model, profile },
"Testing model via inference profile."
);
let result: boolean;
try {
result = await this.testClaudeModel(key, profile);
} catch (e) {
this.log.error(
{ key: key.hash, model, profile, error: e.message },
"Error testing model with inference profile; trying model ID directly."
);
result = false;
}
// If the profile worked, we'll return success. Caller will add the
// model (not the profile) to the list of enabled models, but the
// profile will be used when the key is used for inference.
if (result) return true;
}
return this.testClaudeModel(key, model);
} else if (model.includes("mistral")) {
return this.testMistralModel(key, model);
@@ -287,10 +222,6 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
status === 403 &&
errorMessage?.match(/access to the model with the specified model ID/)
) {
this.log.debug(
{ key: key.hash, model, errorType, data, status, headers },
"Model is not available (principal does not have access)."
);
return false;
}
@@ -299,7 +230,7 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
if (status === 404) {
this.log.debug(
{ region: creds.region, model, key: key.hash },
"Model is not available (not supported in this AWS region)."
"Model not supported in this AWS region."
);
return false;
}
@@ -311,14 +242,14 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
if (!correctErrorType || !correctErrorMessage) {
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"Model is not available (request rejected)."
"AWS InvokeModel test unsuccessful."
);
return false;
}
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"Model is available."
"AWS InvokeModel test successful."
);
return true;
}
@@ -352,7 +283,7 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
if (status === 403 || status === 404) {
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"Model is not available (no access or unsupported region)."
"AWS InvokeModel test returned 403 or 404."
);
return false;
}
@@ -362,38 +293,18 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
if (isBadRequest && !isValidationError) {
this.log.debug(
{ key: key.hash, model, errorType, data, status, headers },
"Model is not available (request rejected)."
"AWS InvokeModel test returned 400 but not a validation error."
);
return false;
}
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"Model is available."
"AWS InvokeModel test successful."
);
return true;
}
/**
 * Requests the list of inference profiles for the region taken from the key's
 * credentials and stores the returned profile IDs on the key via `updateKey`.
 *
 * The request is signed with `AwsKeyChecker.signRequestForAws` before it is
 * sent. Failures from signing or from the HTTP request are not caught here;
 * they propagate to the caller.
 *
 * @param key The AWS Bedrock key whose inference profiles should be listed.
 */
private async checkInferenceProfiles(key: AwsBedrockKey) {
  const { region } = AwsKeyChecker.getCredentialsFromKey(key);

  const request: AxiosRequestConfig = {
    method: "GET",
    url: GET_LIST_INFERENCE_PROFILES_URL(region),
    headers: { accept: "application/json" },
  };
  await AwsKeyChecker.signRequestForAws(request, key);

  const response = await axios.request<GetInferenceProfilesResponse>(request);
  // Only the profile IDs are kept; the rest of each summary is discarded.
  const profileIds = response.data.inferenceProfileSummaries.map(
    ({ inferenceProfileId }) => inferenceProfileId
  );

  this.log.debug(
    { key: key.hash, profileIds, region },
    "Inference profiles found."
  );
  this.updateKey(key.hash, { inferenceProfileIds: profileIds });
}
private async checkLoggingConfiguration(key: AwsBedrockKey) {
if (config.allowAwsLogging) {
// Don't check logging status if we're allowing it to reduce API calls.
@@ -462,8 +373,7 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
method,
protocol: "https:",
hostname: url.hostname,
path: url.pathname,
query: Object.fromEntries(url.searchParams),
path: url.pathname + url.search,
headers: { Host: url.hostname, ...plainHeaders },
});
+1 -17
View File
@@ -22,7 +22,6 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
*/
awsLoggingStatus: "unknown" | "disabled" | "enabled";
modelIds: string[];
inferenceProfileIds: string[];
}
/**
@@ -73,7 +72,6 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
.slice(0, 8)}`,
lastChecked: 0,
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
inferenceProfileIds: [],
["aws-claudeTokens"]: 0,
["aws-claude-opusTokens"]: 0,
["aws-mistral-tinyTokens"]: 0,
@@ -137,21 +135,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
);
}
/**
* Comparator for prioritizing keys on inference profile compatibility.
* Requests made via inference profiles have higher rate limits so we want
* to use keys with compatible inference profiles first.
*/
const hasInferenceProfile = (
a: AwsBedrockKey,
b: AwsBedrockKey
) => {
const aMatch = +a.inferenceProfileIds.some((p) => p.includes(model));
const bMatch = +b.inferenceProfileIds.some((p) => p.includes(model));
return aMatch - bMatch;
};
const selectedKey = prioritizeKeys(availableKeys, hasInferenceProfile)[0];
const selectedKey = prioritizeKeys(availableKeys)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
return { ...selectedKey };
+34 -51
View File
@@ -6,12 +6,10 @@ import { GcpModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
const GCP_HOST =
process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
const POST_STREAM_RAW_URL = (project: string, region: string, model: string) =>
`https://${GCP_HOST.replace(
"%REGION%",
region
)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
`https://${GCP_HOST.replace("%REGION%", region)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
const TEST_MESSAGES = [
{ role: "user", content: "Hi!" },
{ role: "assistant", content: "Hello!" },
@@ -25,7 +23,6 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
service: "gcp",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
}
@@ -41,8 +38,9 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
this.invokeModel("claude-3-5-sonnet@20240620", key, true),
];
const [sonnet, haiku, opus, sonnet35] = await Promise.all(checks);
const [sonnet, haiku, opus, sonnet35] =
await Promise.all(checks);
this.log.debug(
{ key: key.hash, sonnet, haiku, opus, sonnet35 },
"GCP model initial tests complete."
@@ -68,17 +66,20 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
});
} else {
if (key.haikuEnabled) {
await this.invokeModel("claude-3-haiku@20240307", key, false);
await this.invokeModel("claude-3-haiku@20240307", key, false)
} else if (key.sonnetEnabled) {
await this.invokeModel("claude-3-sonnet@20240229", key, false);
await this.invokeModel("claude-3-sonnet@20240229", key, false)
} else if (key.sonnet35Enabled) {
await this.invokeModel("claude-3-5-sonnet@20240620", key, false);
await this.invokeModel("claude-3-5-sonnet@20240620", key, false)
} else {
await this.invokeModel("claude-3-opus@20240229", key, false);
await this.invokeModel("claude-3-opus@20240229", key, false)
}
this.updateKey(key.hash, { lastChecked: Date.now() });
this.log.debug({ key: key.hash }, "GCP key check complete.");
this.log.debug(
{ key: key.hash},
"GCP key check complete."
);
}
this.log.info(
@@ -133,12 +134,8 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
*/
private async invokeModel(model: string, key: GcpKey, initial: boolean) {
const creds = GcpKeyChecker.getCredentialsFromKey(key);
const signedJWT = await GcpKeyChecker.createSignedJWT(
creds.clientEmail,
creds.privateKey
);
const [accessToken, jwtError] =
await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT);
const signedJWT = await GcpKeyChecker.createSignedJWT(creds.clientEmail, creds.privateKey)
const [accessToken, jwtError] = await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT)
if (accessToken === null) {
this.log.warn(
{ key: key.hash, jwtError },
@@ -154,19 +151,15 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
const { data, status } = await axios.post(
POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
payload,
{
{
headers: GcpKeyChecker.getRequestHeaders(accessToken),
validateStatus: initial
? () => true
: (status: number) => status >= 200 && status < 300,
validateStatus: initial ? () => true : (status: number) => status >= 200 && status < 300
}
);
this.log.debug({ key: key.hash, data }, "Response from GCP");
if (initial) {
return (
(status >= 200 && status < 300) || status === 429 || status === 529
);
return (status >= 200 && status < 300) || (status === 429 || status === 529);
}
return true;
@@ -185,7 +178,10 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
let cryptoKey = await crypto.subtle.importKey(
"pkcs8",
GcpKeyChecker.str2ab(atob(pkey)),
{ name: "RSASSA-PKCS1-v1_5", hash: { name: "SHA-256" } },
{
name: "RSASSA-PKCS1-v1_5",
hash: { name: "SHA-256" },
},
false,
["sign"]
);
@@ -194,7 +190,10 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
const issued = Math.floor(Date.now() / 1000);
const expires = issued + 600;
const header = { alg: "RS256", typ: "JWT" };
const header = {
alg: "RS256",
typ: "JWT",
};
const payload = {
iss: email,
@@ -204,12 +203,8 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
scope: "https://www.googleapis.com/auth/cloud-platform",
};
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(
JSON.stringify(header)
);
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(
JSON.stringify(payload)
);
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(header));
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(payload));
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
@@ -223,9 +218,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
return `${unsignedToken}.${encodedSignature}`;
}
static async exchangeJwtForAccessToken(
signed_jwt: string
): Promise<[string | null, string]> {
static async exchangeJwtForAccessToken(signed_jwt: string): Promise<[string | null, string]> {
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
const params = {
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
@@ -259,11 +252,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
static urlSafeBase64Encode(data: string | ArrayBuffer): string {
let base64: string;
if (typeof data === "string") {
base64 = btoa(
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
String.fromCharCode(parseInt("0x" + p1, 16))
)
);
base64 = btoa(encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => String.fromCharCode(parseInt("0x" + p1, 16))));
} else {
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
}
@@ -271,10 +260,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
}
static getRequestHeaders(accessToken: string) {
return {
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
};
return { "Authorization": `Bearer ${accessToken}`, "Content-Type": "application/json" };
}
static getCredentialsFromKey(key: GcpKey) {
@@ -283,12 +269,9 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
throw new Error("Invalid GCP key");
}
const privateKey = rawPrivateKey
.replace(
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
""
)
.replace(/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, '')
.trim();
return { projectId, clientEmail, region, privateKey };
}
}
@@ -7,7 +7,6 @@ type KeyCheckerOptions<TKey extends Key = Key> = {
service: string;
keyCheckPeriod: number;
minCheckInterval: number;
keyCheckBatchSize?: number;
recurringChecksEnabled?: boolean;
updateKey: (hash: string, props: Partial<TKey>) => void;
};
@@ -23,8 +22,6 @@ export abstract class KeyCheckerBase<TKey extends Key> {
* than this.
*/
protected readonly keyCheckPeriod: number;
/** Maximum number of keys to check simultaneously. */
protected readonly keyCheckBatchSize: number;
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
protected readonly keys: TKey[] = [];
protected log: pino.Logger;
@@ -36,7 +33,6 @@ export abstract class KeyCheckerBase<TKey extends Key> {
this.keyCheckPeriod = opts.keyCheckPeriod;
this.minCheckInterval = opts.minCheckInterval;
this.recurringChecksEnabled = opts.recurringChecksEnabled ?? true;
this.keyCheckBatchSize = opts.keyCheckBatchSize ?? 12;
this.updateKey = opts.updateKey;
this.service = opts.service;
this.log = logger.child({ module: "key-checker", service: opts.service });
@@ -82,7 +78,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
if (numUnchecked > 0) {
const keycheckBatch = uncheckedKeys.slice(0, this.keyCheckBatchSize);
const keycheckBatch = uncheckedKeys.slice(0, 12);
this.timeout = setTimeout(async () => {
try {
+7 -22
View File
@@ -1,22 +1,12 @@
import { Key } from "./index";
/**
* Given a list of keys, returns a new list of keys sorted from highest to
* lowest priority. Keys are prioritized in the following order:
*
* 1. Keys which are not rate limited
* a. If all keys were rate limited recently, select the least-recently
* rate limited key.
* b. Otherwise, select the first key.
* 2. Keys which have not been used in the longest time
* 3. Keys according to the custom comparator, if provided
* @param keys The list of keys to sort
* @param customComparator A custom comparator function to use for sorting
*/
export function prioritizeKeys<T extends Key>(
keys: T[],
customComparator?: (a: T, b: T) => number
) {
export function prioritizeKeys<T extends Key>(keys: T[]) {
// Sorts keys from highest priority to lowest priority, where priority is:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 2. Keys which have not been used in the longest time
const now = Date.now();
return keys.sort((a, b) => {
@@ -29,11 +19,6 @@ export function prioritizeKeys<T extends Key>(
return a.rateLimitedAt - b.rateLimitedAt;
}
if (customComparator) {
const result = customComparator(a, b);
if (result !== 0) return result;
}
return a.lastUsed - b.lastUsed;
});
}
-1
View File
@@ -130,7 +130,6 @@ export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
"^gpt-4o(-\\d{4}-\\d{2}-\\d{2})?$": "gpt4o",
"^chatgpt-4o": "gpt4o",
"^gpt-4o-mini(-\\d{4}-\\d{2}-\\d{2})?$": "turbo", // closest match
"^gpt-4-turbo(-\\d{4}-\\d{2}-\\d{2})?$": "gpt4-turbo",
"^gpt-4-turbo(-preview)?$": "gpt4-turbo",
-3
View File
@@ -67,9 +67,6 @@ async function getTokenCountForMessages({
case "image":
numTokens += await getImageTokenCount(part.source.data);
break;
case "tool_use":
case "tool_result":
break;
default:
throw new Error(`Unsupported Anthropic content type.`);
}
+3 -3
View File
@@ -1,14 +1,14 @@
import cookieParser from "cookie-parser";
import expressSession from "express-session";
import MemoryStore from "memorystore";
import { config, SECRET_SIGNING_KEY } from "../config";
import { config, COOKIE_SECRET } from "../config";
const ONE_WEEK = 1000 * 60 * 60 * 24 * 7;
const cookieParserMiddleware = cookieParser(SECRET_SIGNING_KEY);
const cookieParserMiddleware = cookieParser(COOKIE_SECRET);
const sessionMiddleware = expressSession({
secret: SECRET_SIGNING_KEY,
secret: COOKIE_SECRET,
resave: false,
saveUninitialized: false,
store: new (MemoryStore(expressSession))({ checkPeriod: ONE_WEEK }),
+19 -8
View File
@@ -2,7 +2,6 @@ import crypto from "crypto";
import express from "express";
import argon2 from "@node-rs/argon2";
import { z } from "zod";
import { signMessage } from "../../shared/hmac-signing";
import {
authenticate,
createUser,
@@ -14,13 +13,15 @@ import { config } from "../../config";
/** Lockout time after verification in milliseconds */
const LOCKOUT_TIME = 1000 * 60; // 60 seconds
let powKeySalt = crypto.randomBytes(32).toString("hex");
/** HMAC key for signing challenges; regenerated on startup */
let hmacSecret = crypto.randomBytes(32).toString("hex");
/**
* Invalidates any outstanding unsolved challenges.
* Regenerate the HMAC key used for signing challenges. Calling this function
* will invalidate all existing challenges.
*/
export function invalidatePowChallenges() {
powKeySalt = crypto.randomBytes(32).toString("hex");
export function invalidatePowHmacKey() {
hmacSecret = crypto.randomBytes(32).toString("hex");
}
const argon2Params = {
@@ -140,6 +141,16 @@ function generateChallenge(clientIp?: string, token?: string): Challenge {
};
}
/**
 * Generates a hex-encoded HMAC-SHA256 signature for the given message, keyed
 * with the module-level `hmacSecret`.
 *
 * Serialization rules:
 * - objects (including arrays and `null`) are signed as their
 *   `JSON.stringify` output, so the signature depends on property order;
 * - strings are signed as-is;
 * - any other primitive (number, boolean, bigint, undefined, ...) is signed as
 *   its `String()` form. Previously these were passed straight to
 *   `hmac.update()`, which only accepts strings and binary data and throws.
 *
 * @param msg The value to sign.
 * @returns The signature as a lowercase hexadecimal string.
 */
function signMessage(msg: any): string {
  const hmac = crypto.createHmac("sha256", hmacSecret);
  if (typeof msg === "object") {
    hmac.update(JSON.stringify(msg));
  } else if (typeof msg === "string") {
    hmac.update(msg);
  } else {
    // hmac.update() rejects non-string primitives, so stringify them first.
    hmac.update(String(msg));
  }
  return hmac.digest("hex");
}
async function verifySolution(
challenge: Challenge,
solution: string,
@@ -214,11 +225,11 @@ router.post("/challenge", (req, res) => {
return;
}
const challenge = generateChallenge(req.ip, refreshToken);
const signature = signMessage(challenge, powKeySalt);
const signature = signMessage(challenge);
res.json({ challenge, signature });
} else {
const challenge = generateChallenge(req.ip);
const signature = signMessage(challenge, powKeySalt);
const signature = signMessage(challenge);
res.json({ challenge, signature });
}
});
@@ -242,7 +253,7 @@ router.post("/verify", async (req, res) => {
}
const { challenge, signature, solution } = result.data;
if (signMessage(challenge, powKeySalt) !== signature) {
if (signMessage(challenge) !== signature) {
res.status(400).json({
error:
"Invalid signature; server may have restarted since challenge was issued. Please request a new challenge.",