Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9e6fd7c24c | |||
| ac92a19946 | |||
| 96fe974ad0 | |||
| 578615fbd2 | |||
| 5dc4050e52 | |||
| cf615ee62c | |||
| ee61f9be2b | |||
| 0c448cb59d | |||
| 51a9ccceb2 | |||
| ce490efd7d | |||
| 5000e59a61 |
Generated
+84
-6
@@ -17,7 +17,6 @@
|
|||||||
"@smithy/eventstream-serde-node": "^2.1.3",
|
"@smithy/eventstream-serde-node": "^2.1.3",
|
||||||
"@smithy/protocol-http": "^3.2.1",
|
"@smithy/protocol-http": "^3.2.1",
|
||||||
"@smithy/signature-v4": "^2.1.3",
|
"@smithy/signature-v4": "^2.1.3",
|
||||||
"@smithy/types": "^2.10.1",
|
|
||||||
"@smithy/util-utf8": "^2.1.1",
|
"@smithy/util-utf8": "^2.1.1",
|
||||||
"axios": "^1.7.4",
|
"axios": "^1.7.4",
|
||||||
"better-sqlite3": "^10.0.0",
|
"better-sqlite3": "^10.0.0",
|
||||||
@@ -52,6 +51,7 @@
|
|||||||
"zod-error": "^1.5.0"
|
"zod-error": "^1.5.0"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
"@smithy/types": "^3.3.0",
|
||||||
"@types/better-sqlite3": "^7.6.10",
|
"@types/better-sqlite3": "^7.6.10",
|
||||||
"@types/cookie-parser": "^1.4.3",
|
"@types/cookie-parser": "^1.4.3",
|
||||||
"@types/cors": "^2.8.13",
|
"@types/cors": "^2.8.13",
|
||||||
@@ -152,6 +152,17 @@
|
|||||||
"node": ">=14.0.0"
|
"node": ">=14.0.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@aws-sdk/types/node_modules/@smithy/types": {
|
||||||
|
"version": "2.12.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||||
|
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||||
|
"dependencies": {
|
||||||
|
"tslib": "^2.6.2"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=14.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@aws-sdk/util-utf8-browser": {
|
"node_modules/@aws-sdk/util-utf8-browser": {
|
||||||
"version": "3.259.0",
|
"version": "3.259.0",
|
||||||
"resolved": "https://registry.npmjs.org/@aws-sdk/util-utf8-browser/-/util-utf8-browser-3.259.0.tgz",
|
"resolved": "https://registry.npmjs.org/@aws-sdk/util-utf8-browser/-/util-utf8-browser-3.259.0.tgz",
|
||||||
@@ -1328,6 +1339,17 @@
|
|||||||
"tslib": "^2.5.0"
|
"tslib": "^2.5.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@smithy/eventstream-codec/node_modules/@smithy/types": {
|
||||||
|
"version": "2.12.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||||
|
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||||
|
"dependencies": {
|
||||||
|
"tslib": "^2.6.2"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=14.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@smithy/eventstream-serde-node": {
|
"node_modules/@smithy/eventstream-serde-node": {
|
||||||
"version": "2.1.3",
|
"version": "2.1.3",
|
||||||
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-node/-/eventstream-serde-node-2.1.3.tgz",
|
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-node/-/eventstream-serde-node-2.1.3.tgz",
|
||||||
@@ -1341,6 +1363,17 @@
|
|||||||
"node": ">=14.0.0"
|
"node": ">=14.0.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@smithy/eventstream-serde-node/node_modules/@smithy/types": {
|
||||||
|
"version": "2.12.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||||
|
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||||
|
"dependencies": {
|
||||||
|
"tslib": "^2.6.2"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=14.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@smithy/eventstream-serde-universal": {
|
"node_modules/@smithy/eventstream-serde-universal": {
|
||||||
"version": "2.1.3",
|
"version": "2.1.3",
|
||||||
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-universal/-/eventstream-serde-universal-2.1.3.tgz",
|
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-universal/-/eventstream-serde-universal-2.1.3.tgz",
|
||||||
@@ -1354,6 +1387,17 @@
|
|||||||
"node": ">=14.0.0"
|
"node": ">=14.0.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@smithy/eventstream-serde-universal/node_modules/@smithy/types": {
|
||||||
|
"version": "2.12.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||||
|
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||||
|
"dependencies": {
|
||||||
|
"tslib": "^2.6.2"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=14.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@smithy/is-array-buffer": {
|
"node_modules/@smithy/is-array-buffer": {
|
||||||
"version": "2.1.1",
|
"version": "2.1.1",
|
||||||
"resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-2.1.1.tgz",
|
"resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-2.1.1.tgz",
|
||||||
@@ -1377,6 +1421,17 @@
|
|||||||
"node": ">=14.0.0"
|
"node": ">=14.0.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@smithy/protocol-http/node_modules/@smithy/types": {
|
||||||
|
"version": "2.12.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||||
|
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||||
|
"dependencies": {
|
||||||
|
"tslib": "^2.6.2"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=14.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@smithy/signature-v4": {
|
"node_modules/@smithy/signature-v4": {
|
||||||
"version": "2.1.3",
|
"version": "2.1.3",
|
||||||
"resolved": "https://registry.npmjs.org/@smithy/signature-v4/-/signature-v4-2.1.3.tgz",
|
"resolved": "https://registry.npmjs.org/@smithy/signature-v4/-/signature-v4-2.1.3.tgz",
|
||||||
@@ -1395,17 +1450,29 @@
|
|||||||
"node": ">=14.0.0"
|
"node": ">=14.0.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@smithy/types": {
|
"node_modules/@smithy/signature-v4/node_modules/@smithy/types": {
|
||||||
"version": "2.10.1",
|
"version": "2.12.0",
|
||||||
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.10.1.tgz",
|
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||||
"integrity": "sha512-hjQO+4ru4cQ58FluQvKKiyMsFg0A6iRpGm2kqdH8fniyNd2WyanoOsYJfMX/IFLuLxEoW6gnRkNZy1y6fUUhtA==",
|
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"tslib": "^2.5.0"
|
"tslib": "^2.6.2"
|
||||||
},
|
},
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=14.0.0"
|
"node": ">=14.0.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@smithy/types": {
|
||||||
|
"version": "3.3.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-3.3.0.tgz",
|
||||||
|
"integrity": "sha512-IxvBBCTFDHbVoK7zIxqA1ZOdc4QfM5HM7rGleCuHi7L1wnKv5Pn69xXJQ9hgxH60ZVygH9/JG0jRgtUncE3QUA==",
|
||||||
|
"dev": true,
|
||||||
|
"dependencies": {
|
||||||
|
"tslib": "^2.6.2"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=16.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@smithy/util-buffer-from": {
|
"node_modules/@smithy/util-buffer-from": {
|
||||||
"version": "2.1.1",
|
"version": "2.1.1",
|
||||||
"resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-2.1.1.tgz",
|
"resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-2.1.1.tgz",
|
||||||
@@ -1441,6 +1508,17 @@
|
|||||||
"node": ">=14.0.0"
|
"node": ">=14.0.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@smithy/util-middleware/node_modules/@smithy/types": {
|
||||||
|
"version": "2.12.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||||
|
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||||
|
"dependencies": {
|
||||||
|
"tslib": "^2.6.2"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=14.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@smithy/util-uri-escape": {
|
"node_modules/@smithy/util-uri-escape": {
|
||||||
"version": "2.1.1",
|
"version": "2.1.1",
|
||||||
"resolved": "https://registry.npmjs.org/@smithy/util-uri-escape/-/util-uri-escape-2.1.1.tgz",
|
"resolved": "https://registry.npmjs.org/@smithy/util-uri-escape/-/util-uri-escape-2.1.1.tgz",
|
||||||
|
|||||||
+1
-1
@@ -26,7 +26,6 @@
|
|||||||
"@smithy/eventstream-serde-node": "^2.1.3",
|
"@smithy/eventstream-serde-node": "^2.1.3",
|
||||||
"@smithy/protocol-http": "^3.2.1",
|
"@smithy/protocol-http": "^3.2.1",
|
||||||
"@smithy/signature-v4": "^2.1.3",
|
"@smithy/signature-v4": "^2.1.3",
|
||||||
"@smithy/types": "^2.10.1",
|
|
||||||
"@smithy/util-utf8": "^2.1.1",
|
"@smithy/util-utf8": "^2.1.1",
|
||||||
"axios": "^1.7.4",
|
"axios": "^1.7.4",
|
||||||
"better-sqlite3": "^10.0.0",
|
"better-sqlite3": "^10.0.0",
|
||||||
@@ -61,6 +60,7 @@
|
|||||||
"zod-error": "^1.5.0"
|
"zod-error": "^1.5.0"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
"@smithy/types": "^3.3.0",
|
||||||
"@types/better-sqlite3": "^7.6.10",
|
"@types/better-sqlite3": "^7.6.10",
|
||||||
"@types/cookie-parser": "^1.4.3",
|
"@types/cookie-parser": "^1.4.3",
|
||||||
"@types/cors": "^2.8.13",
|
"@types/cors": "^2.8.13",
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import {
|
|||||||
} from "../../shared/users/schema";
|
} from "../../shared/users/schema";
|
||||||
import { getLastNImages } from "../../shared/file-storage/image-history";
|
import { getLastNImages } from "../../shared/file-storage/image-history";
|
||||||
import { blacklists, parseCidrs, whitelists } from "../../shared/cidr";
|
import { blacklists, parseCidrs, whitelists } from "../../shared/cidr";
|
||||||
import { invalidatePowHmacKey } from "../../user/web/pow-captcha";
|
import { invalidatePowChallenges } from "../../user/web/pow-captcha";
|
||||||
|
|
||||||
const router = Router();
|
const router = Router();
|
||||||
|
|
||||||
@@ -323,7 +323,7 @@ router.post("/maintenance", (req, res) => {
|
|||||||
user.disabledReason = "Admin forced expiration.";
|
user.disabledReason = "Admin forced expiration.";
|
||||||
userStore.upsertUser(user);
|
userStore.upsertUser(user);
|
||||||
});
|
});
|
||||||
invalidatePowHmacKey();
|
invalidatePowChallenges();
|
||||||
flash.type = "success";
|
flash.type = "success";
|
||||||
flash.message = `${temps.length} temporary users marked for expiration.`;
|
flash.message = `${temps.length} temporary users marked for expiration.`;
|
||||||
break;
|
break;
|
||||||
@@ -348,6 +348,7 @@ router.post("/maintenance", (req, res) => {
|
|||||||
throw new HttpError(400, "Invalid difficulty" + selected);
|
throw new HttpError(400, "Invalid difficulty" + selected);
|
||||||
}
|
}
|
||||||
config.powDifficultyLevel = selected;
|
config.powDifficultyLevel = selected;
|
||||||
|
invalidatePowChallenges();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case "generateTempIpReport": {
|
case "generateTempIpReport": {
|
||||||
|
|||||||
+1
-1
@@ -519,7 +519,7 @@ function generateSigningKey() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const signingKey = generateSigningKey();
|
const signingKey = generateSigningKey();
|
||||||
export const COOKIE_SECRET = signingKey;
|
export const SECRET_SIGNING_KEY = signingKey;
|
||||||
|
|
||||||
export async function assertConfigIsValid() {
|
export async function assertConfigIsValid() {
|
||||||
if (process.env.MODEL_RATE_LIMIT !== undefined) {
|
if (process.env.MODEL_RATE_LIMIT !== undefined) {
|
||||||
|
|||||||
+1
-4
@@ -1,6 +1,5 @@
|
|||||||
import { Request, RequestHandler, Response, Router } from "express";
|
import { Request, RequestHandler, Router } from "express";
|
||||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||||
import { v4 } from "uuid";
|
|
||||||
import { config } from "../config";
|
import { config } from "../config";
|
||||||
import { logger } from "../logger";
|
import { logger } from "../logger";
|
||||||
import { createQueueMiddleware } from "./queue";
|
import { createQueueMiddleware } from "./queue";
|
||||||
@@ -17,8 +16,6 @@ import {
|
|||||||
createOnProxyResHandler,
|
createOnProxyResHandler,
|
||||||
} from "./middleware/response";
|
} from "./middleware/response";
|
||||||
import { transformAnthropicChatResponseToOpenAI } from "./anthropic";
|
import { transformAnthropicChatResponseToOpenAI } from "./anthropic";
|
||||||
import { sendErrorToClient } from "./middleware/response/error-generator";
|
|
||||||
|
|
||||||
const LATEST_GCP_SONNET_MINOR_VERSION = "20240229";
|
const LATEST_GCP_SONNET_MINOR_VERSION = "20240229";
|
||||||
|
|
||||||
let modelsCache: any = null;
|
let modelsCache: any = null;
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ googleAIRouter.post(
|
|||||||
outApi: "google-ai",
|
outApi: "google-ai",
|
||||||
service: "google-ai",
|
service: "google-ai",
|
||||||
},
|
},
|
||||||
{ afterTransform: [maybeReassignModel, setStreamFlag] }
|
{ beforeTransform: [maybeReassignModel], afterTransform: [setStreamFlag] }
|
||||||
),
|
),
|
||||||
googleAIProxy
|
googleAIProxy
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
|
|||||||
const ANTHROPIC_MESSAGES_ENDPOINT = "/v1/messages";
|
const ANTHROPIC_MESSAGES_ENDPOINT = "/v1/messages";
|
||||||
const ANTHROPIC_SONNET_COMPAT_ENDPOINT = "/v1/sonnet";
|
const ANTHROPIC_SONNET_COMPAT_ENDPOINT = "/v1/sonnet";
|
||||||
const ANTHROPIC_OPUS_COMPAT_ENDPOINT = "/v1/opus";
|
const ANTHROPIC_OPUS_COMPAT_ENDPOINT = "/v1/opus";
|
||||||
|
const GOOGLE_AI_COMPLETION_ENDPOINT = "/v1beta/models";
|
||||||
|
|
||||||
export function isTextGenerationRequest(req: Request) {
|
export function isTextGenerationRequest(req: Request) {
|
||||||
return (
|
return (
|
||||||
@@ -27,6 +28,7 @@ export function isTextGenerationRequest(req: Request) {
|
|||||||
ANTHROPIC_MESSAGES_ENDPOINT,
|
ANTHROPIC_MESSAGES_ENDPOINT,
|
||||||
ANTHROPIC_SONNET_COMPAT_ENDPOINT,
|
ANTHROPIC_SONNET_COMPAT_ENDPOINT,
|
||||||
ANTHROPIC_OPUS_COMPAT_ENDPOINT,
|
ANTHROPIC_OPUS_COMPAT_ENDPOINT,
|
||||||
|
GOOGLE_AI_COMPLETION_ENDPOINT,
|
||||||
].some((endpoint) => req.path.startsWith(endpoint))
|
].some((endpoint) => req.path.startsWith(endpoint))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ async function executePreprocessors(
|
|||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error.constructor.name === "ZodError") {
|
if (error.constructor.name === "ZodError") {
|
||||||
const msg = error?.issues
|
const msg = error?.issues
|
||||||
?.map((issue: ZodIssue) => issue.message)
|
?.map((issue: ZodIssue) => `${issue.path.join(".")}: ${issue.message}`)
|
||||||
.join("; ");
|
.join("; ");
|
||||||
req.log.warn({ issues: msg }, "Prompt validation failed.");
|
req.log.warn({ issues: msg }, "Prompt validation failed.");
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -30,10 +30,13 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
|
|||||||
}
|
}
|
||||||
case "anthropic-chat": {
|
case "anthropic-chat": {
|
||||||
req.outputTokens = req.body.max_tokens;
|
req.outputTokens = req.body.max_tokens;
|
||||||
const prompt = {
|
let system = req.body.system ?? "";
|
||||||
system: req.body.system ?? "",
|
if (Array.isArray(system)) {
|
||||||
messages: req.body.messages,
|
system = system
|
||||||
};
|
.map((m: { type: string; text: string }) => m.text)
|
||||||
|
.join("\n");
|
||||||
|
}
|
||||||
|
const prompt = { system, messages: req.body.messages };
|
||||||
result = await countTokens({ req, prompt, service });
|
result = await countTokens({ req, prompt, service });
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import {
|
|||||||
AnthropicV1TextSchema,
|
AnthropicV1TextSchema,
|
||||||
AnthropicV1MessagesSchema,
|
AnthropicV1MessagesSchema,
|
||||||
} from "../../../../shared/api-schemas";
|
} from "../../../../shared/api-schemas";
|
||||||
import { keyPool } from "../../../../shared/key-management";
|
import { AwsBedrockKey, keyPool } from "../../../../shared/key-management";
|
||||||
import { RequestPreprocessor } from "../index";
|
import { RequestPreprocessor } from "../index";
|
||||||
import {
|
import {
|
||||||
AWSMistralV1ChatCompletionsSchema,
|
AWSMistralV1ChatCompletionsSchema,
|
||||||
@@ -40,13 +40,21 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
|
|||||||
// set it so that the stream adapter always selects the correct transformer.
|
// set it so that the stream adapter always selects the correct transformer.
|
||||||
req.headers["anthropic-version"] = "2023-06-01";
|
req.headers["anthropic-version"] = "2023-06-01";
|
||||||
|
|
||||||
|
// If our key has an inference profile compatible with the requested model,
|
||||||
|
// we want to use the inference profile instead of the model ID when calling
|
||||||
|
// InvokeModel as that will give us higher rate limits.
|
||||||
|
const profile =
|
||||||
|
(req.key as AwsBedrockKey).inferenceProfileIds.find((p) =>
|
||||||
|
p.includes(model)
|
||||||
|
) || model;
|
||||||
|
|
||||||
// Uses the AWS SDK to sign a request, then modifies our HPM proxy request
|
// Uses the AWS SDK to sign a request, then modifies our HPM proxy request
|
||||||
// with the headers generated by the SDK.
|
// with the headers generated by the SDK.
|
||||||
const newRequest = new HttpRequest({
|
const newRequest = new HttpRequest({
|
||||||
method: "POST",
|
method: "POST",
|
||||||
protocol: "https:",
|
protocol: "https:",
|
||||||
hostname: host,
|
hostname: host,
|
||||||
path: `/model/${model}/invoke${stream ? "-with-response-stream" : ""}`,
|
path: `/model/${profile}/invoke${stream ? "-with-response-stream" : ""}`,
|
||||||
headers: {
|
headers: {
|
||||||
["Host"]: host,
|
["Host"]: host,
|
||||||
["content-type"]: "application/json",
|
["content-type"]: "application/json",
|
||||||
@@ -62,7 +70,13 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
|
|||||||
|
|
||||||
const { key, body, inboundApi, outboundApi } = req;
|
const { key, body, inboundApi, outboundApi } = req;
|
||||||
req.log.info(
|
req.log.info(
|
||||||
{ key: key.hash, model: body.model, inboundApi, outboundApi },
|
{
|
||||||
|
key: key.hash,
|
||||||
|
model: body.model,
|
||||||
|
inferenceProfile: profile,
|
||||||
|
inboundApi,
|
||||||
|
outboundApi,
|
||||||
|
},
|
||||||
"Assigned AWS credentials to request"
|
"Assigned AWS credentials to request"
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -130,6 +144,8 @@ function applyAwsStrictValidation(req: Request): unknown {
|
|||||||
temperature: true,
|
temperature: true,
|
||||||
top_k: true,
|
top_k: true,
|
||||||
top_p: true,
|
top_p: true,
|
||||||
|
tools: true,
|
||||||
|
tool_choice: true,
|
||||||
})
|
})
|
||||||
.strip()
|
.strip()
|
||||||
.parse(req.body);
|
.parse(req.body);
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ export const signGcpRequest: RequestPreprocessor = async (req) => {
|
|||||||
req.isStreaming = String(stream) === "true";
|
req.isStreaming = String(stream) === "true";
|
||||||
|
|
||||||
// TODO: This should happen in transform-outbound-payload.ts
|
// TODO: This should happen in transform-outbound-payload.ts
|
||||||
// TODO: Support tools
|
|
||||||
let strippedParams: Record<string, unknown>;
|
let strippedParams: Record<string, unknown>;
|
||||||
strippedParams = AnthropicV1MessagesSchema.pick({
|
strippedParams = AnthropicV1MessagesSchema.pick({
|
||||||
messages: true,
|
messages: true,
|
||||||
@@ -34,6 +33,8 @@ export const signGcpRequest: RequestPreprocessor = async (req) => {
|
|||||||
temperature: true,
|
temperature: true,
|
||||||
top_k: true,
|
top_k: true,
|
||||||
top_p: true,
|
top_p: true,
|
||||||
|
tools: true,
|
||||||
|
tool_choice: true,
|
||||||
stream: true,
|
stream: true,
|
||||||
})
|
})
|
||||||
.strip()
|
.strip()
|
||||||
|
|||||||
@@ -17,7 +17,17 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
|
|||||||
const notTransformable =
|
const notTransformable =
|
||||||
!isTextGenerationRequest(req) && !isImageGenerationRequest(req);
|
!isTextGenerationRequest(req) && !isImageGenerationRequest(req);
|
||||||
|
|
||||||
if (alreadyTransformed || notTransformable) return;
|
if (alreadyTransformed) {
|
||||||
|
return;
|
||||||
|
} else if (notTransformable) {
|
||||||
|
// This is probably an indication of a bug in the proxy.
|
||||||
|
const { inboundApi, outboundApi, method, path } = req;
|
||||||
|
req.log.warn(
|
||||||
|
{ inboundApi, outboundApi, method, path },
|
||||||
|
"`transformOutboundPayload` called on a non-transformable request."
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
applyMistralPromptFixes(req);
|
applyMistralPromptFixes(req);
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,13 @@ const getPromptForRequest = (
|
|||||||
case "mistral-ai":
|
case "mistral-ai":
|
||||||
return req.body.messages;
|
return req.body.messages;
|
||||||
case "anthropic-chat":
|
case "anthropic-chat":
|
||||||
return { system: req.body.system, messages: req.body.messages };
|
let system = req.body.system;
|
||||||
|
if (Array.isArray(system)) {
|
||||||
|
system = system
|
||||||
|
.map((m: { type: string; text: string }) => m.text)
|
||||||
|
.join("\n");
|
||||||
|
}
|
||||||
|
return { system, messages: req.body.messages };
|
||||||
case "openai-text":
|
case "openai-text":
|
||||||
case "anthropic-text":
|
case "anthropic-text":
|
||||||
case "mistral-text":
|
case "mistral-text":
|
||||||
|
|||||||
+12
-48
@@ -22,7 +22,7 @@ import {
|
|||||||
} from "../shared/models";
|
} from "../shared/models";
|
||||||
import { initializeSseStream } from "../shared/streaming";
|
import { initializeSseStream } from "../shared/streaming";
|
||||||
import { logger } from "../logger";
|
import { logger } from "../logger";
|
||||||
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
|
import { getUniqueIps } from "./rate-limit";
|
||||||
import { RequestPreprocessor } from "./middleware/request";
|
import { RequestPreprocessor } from "./middleware/request";
|
||||||
import { handleProxyError } from "./middleware/common";
|
import { handleProxyError } from "./middleware/common";
|
||||||
import { sendErrorToClient } from "./middleware/response/error-generator";
|
import { sendErrorToClient } from "./middleware/response/error-generator";
|
||||||
@@ -31,7 +31,9 @@ const queue: Request[] = [];
|
|||||||
const log = logger.child({ module: "request-queue" });
|
const log = logger.child({ module: "request-queue" });
|
||||||
|
|
||||||
/** Maximum number of queue slots for individual users. */
|
/** Maximum number of queue slots for individual users. */
|
||||||
const USER_CONCURRENCY_LIMIT = parseInt(process.env.USER_CONCURRENCY_LIMIT ?? "1");
|
const USER_CONCURRENCY_LIMIT = parseInt(
|
||||||
|
process.env.USER_CONCURRENCY_LIMIT ?? "1"
|
||||||
|
);
|
||||||
/** Maximum number of queue slots for Agnai.chat requests. */
|
/** Maximum number of queue slots for Agnai.chat requests. */
|
||||||
const AGNAI_CONCURRENCY_LIMIT = USER_CONCURRENCY_LIMIT * 5;
|
const AGNAI_CONCURRENCY_LIMIT = USER_CONCURRENCY_LIMIT * 5;
|
||||||
const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
|
const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
|
||||||
@@ -58,39 +60,20 @@ const QUEUE_JOIN_TIMEOUT = 5000;
|
|||||||
function getIdentifier(req: Request) {
|
function getIdentifier(req: Request) {
|
||||||
if (req.user) return req.user.token;
|
if (req.user) return req.user.token;
|
||||||
if (req.risuToken) return req.risuToken;
|
if (req.risuToken) return req.risuToken;
|
||||||
if (isFromSharedIp(req)) return "shared-ip";
|
// if (isFromSharedIp(req)) return "shared-ip";
|
||||||
return req.ip;
|
return req.ip;
|
||||||
}
|
}
|
||||||
|
|
||||||
const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
|
const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
|
||||||
getIdentifier(queued) === getIdentifier(incoming);
|
getIdentifier(queued) === getIdentifier(incoming);
|
||||||
|
|
||||||
const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);
|
|
||||||
|
|
||||||
async function enqueue(req: Request) {
|
async function enqueue(req: Request) {
|
||||||
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
|
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
|
||||||
let isGuest = req.user?.token === undefined;
|
|
||||||
|
|
||||||
// Requests from shared IP addresses such as Agnai.chat are exempt from IP-
|
if (enqueuedRequestCount >= USER_CONCURRENCY_LIMIT) {
|
||||||
// based rate limiting but can only occupy a certain number of slots in the
|
throw new TooManyRequestsError(
|
||||||
// queue. Authenticated users always get a single spot in the queue.
|
"Your IP or user token already has another request in the queue."
|
||||||
const isSharedIp = isFromSharedIp(req);
|
);
|
||||||
const maxConcurrentQueuedRequests =
|
|
||||||
isGuest && isSharedIp ? AGNAI_CONCURRENCY_LIMIT : USER_CONCURRENCY_LIMIT;
|
|
||||||
if (enqueuedRequestCount >= maxConcurrentQueuedRequests) {
|
|
||||||
if (isSharedIp) {
|
|
||||||
// Re-enqueued requests are not counted towards the limit since they
|
|
||||||
// already made it through the queue once.
|
|
||||||
if (req.retryCount === 0) {
|
|
||||||
throw new TooManyRequestsError(
|
|
||||||
"Too many agnai.chat requests are already queued"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw new TooManyRequestsError(
|
|
||||||
"Your IP or user token already has another request in the queue."
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// shitty hack to remove hpm's event listeners on retried requests
|
// shitty hack to remove hpm's event listeners on retried requests
|
||||||
@@ -146,19 +129,7 @@ export async function reenqueueRequest(req: Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function getQueueForPartition(partition: ModelFamily): Request[] {
|
function getQueueForPartition(partition: ModelFamily): Request[] {
|
||||||
return queue
|
return queue.filter((req) => getModelFamilyForRequest(req) === partition);
|
||||||
.filter((req) => getModelFamilyForRequest(req) === partition)
|
|
||||||
.sort((a, b) => {
|
|
||||||
// Certain requests are exempted from IP-based rate limiting because they
|
|
||||||
// come from a shared IP address. To prevent these requests from starving
|
|
||||||
// out other requests during periods of high traffic, we sort them to the
|
|
||||||
// end of the queue.
|
|
||||||
const aIsExempted = isFromSharedIp(a);
|
|
||||||
const bIsExempted = isFromSharedIp(b);
|
|
||||||
if (aIsExempted && !bIsExempted) return 1;
|
|
||||||
if (!aIsExempted && bIsExempted) return -1;
|
|
||||||
return 0;
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function dequeue(partition: ModelFamily): Request | undefined {
|
export function dequeue(partition: ModelFamily): Request | undefined {
|
||||||
@@ -261,7 +232,6 @@ let waitTimes: {
|
|||||||
partition: ModelFamily;
|
partition: ModelFamily;
|
||||||
start: number;
|
start: number;
|
||||||
end: number;
|
end: number;
|
||||||
isDeprioritized: boolean;
|
|
||||||
}[] = [];
|
}[] = [];
|
||||||
|
|
||||||
/** Adds a successful request to the list of wait times. */
|
/** Adds a successful request to the list of wait times. */
|
||||||
@@ -270,7 +240,6 @@ export function trackWaitTime(req: Request) {
|
|||||||
partition: getModelFamilyForRequest(req),
|
partition: getModelFamilyForRequest(req),
|
||||||
start: req.startTime!,
|
start: req.startTime!,
|
||||||
end: req.queueOutTime ?? Date.now(),
|
end: req.queueOutTime ?? Date.now(),
|
||||||
isDeprioritized: isFromSharedIp(req),
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -296,8 +265,7 @@ function calculateWaitTime(partition: ModelFamily) {
|
|||||||
.filter((wait) => {
|
.filter((wait) => {
|
||||||
const isSamePartition = wait.partition === partition;
|
const isSamePartition = wait.partition === partition;
|
||||||
const isRecent = now - wait.end < 300 * 1000;
|
const isRecent = now - wait.end < 300 * 1000;
|
||||||
const isNormalPriority = !wait.isDeprioritized;
|
return isSamePartition && isRecent;
|
||||||
return isSamePartition && isRecent && isNormalPriority;
|
|
||||||
})
|
})
|
||||||
.map((wait) => wait.end - wait.start);
|
.map((wait) => wait.end - wait.start);
|
||||||
const recentAverage = recentWaits.length
|
const recentAverage = recentWaits.length
|
||||||
@@ -311,11 +279,7 @@ function calculateWaitTime(partition: ModelFamily) {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const currentWaits = queue
|
const currentWaits = queue
|
||||||
.filter((req) => {
|
.filter((req) => getModelFamilyForRequest(req) === partition)
|
||||||
const isSamePartition = getModelFamilyForRequest(req) === partition;
|
|
||||||
const isNormalPriority = !isFromSharedIp(req);
|
|
||||||
return isSamePartition && isNormalPriority;
|
|
||||||
})
|
|
||||||
.map((req) => now - req.startTime!);
|
.map((req) => now - req.startTime!);
|
||||||
const longestCurrentWait = Math.max(...currentWaits, 0);
|
const longestCurrentWait = Math.max(...currentWaits, 0);
|
||||||
|
|
||||||
|
|||||||
+15
-32
@@ -1,14 +1,6 @@
|
|||||||
import { Request, Response, NextFunction } from "express";
|
import { Request, Response, NextFunction } from "express";
|
||||||
import { config } from "../config";
|
import { config } from "../config";
|
||||||
|
|
||||||
export const SHARED_IP_ADDRESSES = new Set([
|
|
||||||
// Agnai.chat
|
|
||||||
"157.230.249.32", // old
|
|
||||||
"157.245.148.56",
|
|
||||||
"174.138.29.50",
|
|
||||||
"209.97.162.44",
|
|
||||||
]);
|
|
||||||
|
|
||||||
const ONE_MINUTE_MS = 60 * 1000;
|
const ONE_MINUTE_MS = 60 * 1000;
|
||||||
|
|
||||||
type Timestamp = number;
|
type Timestamp = number;
|
||||||
@@ -20,7 +12,10 @@ const exemptedRequests: Timestamp[] = [];
|
|||||||
const isRecentAttempt = (now: Timestamp) => (attempt: Timestamp) =>
|
const isRecentAttempt = (now: Timestamp) => (attempt: Timestamp) =>
|
||||||
attempt > now - ONE_MINUTE_MS;
|
attempt > now - ONE_MINUTE_MS;
|
||||||
|
|
||||||
const getTryAgainInMs = (ip: string, type: "text" | "image") => {
|
/**
|
||||||
|
* Returns duration in seconds to wait before retrying for Retry-After header.
|
||||||
|
*/
|
||||||
|
const getRetryAfter = (ip: string, type: "text" | "image") => {
|
||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
const attempts = lastAttempts.get(ip) || [];
|
const attempts = lastAttempts.get(ip) || [];
|
||||||
const validAttempts = attempts.filter(isRecentAttempt(now));
|
const validAttempts = attempts.filter(isRecentAttempt(now));
|
||||||
@@ -29,7 +24,7 @@ const getTryAgainInMs = (ip: string, type: "text" | "image") => {
|
|||||||
type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;
|
type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;
|
||||||
|
|
||||||
if (validAttempts.length >= limit) {
|
if (validAttempts.length >= limit) {
|
||||||
return validAttempts[0] - now + ONE_MINUTE_MS;
|
return (validAttempts[0] - now + ONE_MINUTE_MS) / 1000;
|
||||||
} else {
|
} else {
|
||||||
lastAttempts.set(ip, [...validAttempts, now]);
|
lastAttempts.set(ip, [...validAttempts, now]);
|
||||||
return 0;
|
return 0;
|
||||||
@@ -96,22 +91,11 @@ export const ipLimiter = async (
|
|||||||
if (!textLimit && !imageLimit) return next();
|
if (!textLimit && !imageLimit) return next();
|
||||||
if (req.user?.type === "special") return next();
|
if (req.user?.type === "special") return next();
|
||||||
|
|
||||||
// Exempts Agnai.chat from IP-based rate limiting because its IPs are shared
|
const path = req.baseUrl + req.path;
|
||||||
// by many users. Instead, the request queue will limit the number of such
|
const type =
|
||||||
// requests that may wait in the queue at a time, and sorts them to the end to
|
path.includes("openai-image") || path.includes("images/generations")
|
||||||
// let individual users go first.
|
? "image"
|
||||||
if (SHARED_IP_ADDRESSES.has(req.ip)) {
|
: "text";
|
||||||
exemptedRequests.push(Date.now());
|
|
||||||
req.log.info(
|
|
||||||
{ ip: req.ip, recentExemptions: exemptedRequests.length },
|
|
||||||
"Exempting Agnai request from rate limiting."
|
|
||||||
);
|
|
||||||
return next();
|
|
||||||
}
|
|
||||||
|
|
||||||
const type = (req.baseUrl + req.path).includes("openai-image")
|
|
||||||
? "image"
|
|
||||||
: "text";
|
|
||||||
const limit = type === "image" ? imageLimit : textLimit;
|
const limit = type === "image" ? imageLimit : textLimit;
|
||||||
|
|
||||||
// If user is authenticated, key rate limiting by their token. Otherwise, key
|
// If user is authenticated, key rate limiting by their token. Otherwise, key
|
||||||
@@ -123,15 +107,14 @@ export const ipLimiter = async (
|
|||||||
res.set("X-RateLimit-Remaining", remaining.toString());
|
res.set("X-RateLimit-Remaining", remaining.toString());
|
||||||
res.set("X-RateLimit-Reset", reset.toString());
|
res.set("X-RateLimit-Reset", reset.toString());
|
||||||
|
|
||||||
const tryAgainInMs = getTryAgainInMs(rateLimitKey, type);
|
const retryAfterTime = getRetryAfter(rateLimitKey, type);
|
||||||
if (tryAgainInMs > 0) {
|
if (retryAfterTime > 0) {
|
||||||
res.set("Retry-After", tryAgainInMs.toString());
|
const waitSec = Math.ceil(retryAfterTime).toString();
|
||||||
|
res.set("Retry-After", waitSec);
|
||||||
res.status(429).json({
|
res.status(429).json({
|
||||||
error: {
|
error: {
|
||||||
type: "proxy_rate_limited",
|
type: "proxy_rate_limited",
|
||||||
message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${Math.ceil(
|
message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${waitSec} seconds.`,
|
||||||
tryAgainInMs / 1000
|
|
||||||
)} seconds.`,
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ app.use(
|
|||||||
// Don't log the prompt text on transform errors
|
// Don't log the prompt text on transform errors
|
||||||
"body.messages",
|
"body.messages",
|
||||||
"body.prompt",
|
"body.prompt",
|
||||||
|
"body.contents",
|
||||||
],
|
],
|
||||||
censor: "********",
|
censor: "********",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -19,7 +19,12 @@ const AnthropicV1BaseSchema = z
|
|||||||
top_k: z.coerce.number().optional(),
|
top_k: z.coerce.number().optional(),
|
||||||
top_p: z.coerce.number().optional(),
|
top_p: z.coerce.number().optional(),
|
||||||
metadata: z.object({ user_id: z.string().optional() }).optional(),
|
metadata: z.object({ user_id: z.string().optional() }).optional(),
|
||||||
|
tools: z.array(z.any()).optional(),
|
||||||
|
tool_choice: z.any().optional(),
|
||||||
})
|
})
|
||||||
|
.omit(
|
||||||
|
Boolean(config.allowOpenAIToolUsage) ? {} : { tools: true, tool_choice: true }
|
||||||
|
)
|
||||||
.strip();
|
.strip();
|
||||||
|
|
||||||
// https://docs.anthropic.com/claude/reference/complete_post [deprecated]
|
// https://docs.anthropic.com/claude/reference/complete_post [deprecated]
|
||||||
@@ -44,6 +49,18 @@ const AnthropicV1MessageMultimodalContentSchema = z.array(
|
|||||||
data: z.string(),
|
data: z.string(),
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
z.object({
|
||||||
|
type: z.literal("tool_use"),
|
||||||
|
id: z.string(),
|
||||||
|
name: z.string(),
|
||||||
|
input: z.object({}).passthrough(),
|
||||||
|
}),
|
||||||
|
z.object({
|
||||||
|
type: z.literal("tool_result"),
|
||||||
|
tool_use_id: z.string(),
|
||||||
|
is_error: z.boolean().optional(),
|
||||||
|
content: z.union([z.string(), z.object({}).passthrough()]).optional(),
|
||||||
|
}),
|
||||||
])
|
])
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -63,7 +80,12 @@ export const AnthropicV1MessagesSchema = AnthropicV1BaseSchema.merge(
|
|||||||
.number()
|
.number()
|
||||||
.int()
|
.int()
|
||||||
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
|
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
|
||||||
system: z.string().optional(),
|
system: z
|
||||||
|
.union([
|
||||||
|
z.string(),
|
||||||
|
z.array(z.object({ type: z.literal("text"), text: z.string() })),
|
||||||
|
])
|
||||||
|
.optional(),
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
export type AnthropicChatMessage = z.infer<
|
export type AnthropicChatMessage = z.infer<
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ export const GoogleAIV1GenerateContentSchema = z
|
|||||||
topP: z.number().optional(),
|
topP: z.number().optional(),
|
||||||
topK: z.number().optional(),
|
topK: z.number().optional(),
|
||||||
stopSequences: z.array(z.string().max(500)).max(5).optional(),
|
stopSequences: z.array(z.string().max(500)).max(5).optional(),
|
||||||
}),
|
}).default({}),
|
||||||
})
|
})
|
||||||
.strip();
|
.strip();
|
||||||
export type GoogleAIChatMessage = z.infer<
|
export type GoogleAIChatMessage = z.infer<
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
/** Module for generating and verifying HMAC signatures. */
|
||||||
|
|
||||||
|
import crypto from "crypto";
|
||||||
|
import { SECRET_SIGNING_KEY } from "../config";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generates a HMAC signature for the given message. Optionally salts the
|
||||||
|
* key with a provided string.
|
||||||
|
*/
|
||||||
|
export function signMessage(msg: any, salt: string = ""): string {
|
||||||
|
const hmac = crypto.createHmac("sha256", SECRET_SIGNING_KEY + salt);
|
||||||
|
if (typeof msg === "object") {
|
||||||
|
hmac.update(JSON.stringify(msg));
|
||||||
|
} else {
|
||||||
|
hmac.update(msg);
|
||||||
|
}
|
||||||
|
return hmac.digest("hex");
|
||||||
|
}
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
import { doubleCsrf } from "csrf-csrf";
|
import { doubleCsrf } from "csrf-csrf";
|
||||||
import express from "express";
|
import express from "express";
|
||||||
import { config, COOKIE_SECRET } from "../config";
|
import { config, SECRET_SIGNING_KEY } from "../config";
|
||||||
|
|
||||||
const { generateToken, doubleCsrfProtection } = doubleCsrf({
|
const { generateToken, doubleCsrfProtection } = doubleCsrf({
|
||||||
getSecret: () => COOKIE_SECRET,
|
getSecret: () => SECRET_SIGNING_KEY,
|
||||||
cookieName: "csrf",
|
cookieName: "csrf",
|
||||||
cookieOptions: {
|
cookieOptions: {
|
||||||
sameSite: "strict",
|
sameSite: "strict",
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
import { Sha256 } from "@aws-crypto/sha256-js";
|
import { Sha256 } from "@aws-crypto/sha256-js";
|
||||||
import { SignatureV4 } from "@smithy/signature-v4";
|
import { SignatureV4 } from "@smithy/signature-v4";
|
||||||
import { HttpRequest } from "@smithy/protocol-http";
|
import { HttpRequest } from "@smithy/protocol-http";
|
||||||
import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios";
|
import axios, { AxiosError, AxiosHeaders, AxiosRequestConfig } from "axios";
|
||||||
import { URL } from "url";
|
import { URL } from "url";
|
||||||
|
import { config } from "../../../config";
|
||||||
|
import { getAwsBedrockModelFamily } from "../../models";
|
||||||
import { KeyCheckerBase } from "../key-checker-base";
|
import { KeyCheckerBase } from "../key-checker-base";
|
||||||
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
|
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
|
||||||
import { getAwsBedrockModelFamily } from "../../models";
|
|
||||||
import { config } from "../../../config";
|
|
||||||
|
|
||||||
type ParentModelId = string;
|
type ParentModelId = string;
|
||||||
type AliasModelId = string;
|
type AliasModelId = string;
|
||||||
@@ -24,6 +24,8 @@ const KNOWN_MODEL_IDS: ModuleAliasTuple[] = [
|
|||||||
["mistral.mistral-large-2407-v1:0"],
|
["mistral.mistral-large-2407-v1:0"],
|
||||||
["mistral.mistral-small-2402-v1:0"], // Seems to return 400
|
["mistral.mistral-small-2402-v1:0"], // Seems to return 400
|
||||||
];
|
];
|
||||||
|
|
||||||
|
const KEY_CHECK_BATCH_SIZE = 2; // AWS checker needs to do lots of concurrent requests so should lower the batch size
|
||||||
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||||
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
|
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
|
||||||
const AMZ_HOST =
|
const AMZ_HOST =
|
||||||
@@ -31,6 +33,8 @@ const AMZ_HOST =
|
|||||||
const GET_CALLER_IDENTITY_URL = `https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15`;
|
const GET_CALLER_IDENTITY_URL = `https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15`;
|
||||||
const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) =>
|
const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) =>
|
||||||
`https://bedrock.${region}.amazonaws.com/logging/modelinvocations`;
|
`https://bedrock.${region}.amazonaws.com/logging/modelinvocations`;
|
||||||
|
const GET_LIST_INFERENCE_PROFILES_URL = (region: string) =>
|
||||||
|
`https://bedrock.${region}.amazonaws.com/inference-profiles?maxResults=1000`;
|
||||||
const POST_INVOKE_MODEL_URL = (region: string, model: string) =>
|
const POST_INVOKE_MODEL_URL = (region: string, model: string) =>
|
||||||
`https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`;
|
`https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`;
|
||||||
const TEST_MESSAGES = [
|
const TEST_MESSAGES = [
|
||||||
@@ -40,6 +44,22 @@ const TEST_MESSAGES = [
|
|||||||
|
|
||||||
type AwsError = { error: {} };
|
type AwsError = { error: {} };
|
||||||
|
|
||||||
|
type GetInferenceProfilesResponse = {
|
||||||
|
inferenceProfileSummaries: {
|
||||||
|
inferenceProfileId: string;
|
||||||
|
inferenceProfileName: string;
|
||||||
|
inferenceProfileArn: string;
|
||||||
|
description?: string;
|
||||||
|
createdAt?: string;
|
||||||
|
updatedAt?: string;
|
||||||
|
status: "ACTIVE" | unknown;
|
||||||
|
type: "SYSTEM_DEFINED" | unknown;
|
||||||
|
models: {
|
||||||
|
modelArn?: string;
|
||||||
|
}[];
|
||||||
|
}[];
|
||||||
|
};
|
||||||
|
|
||||||
type GetLoggingConfigResponse = {
|
type GetLoggingConfigResponse = {
|
||||||
loggingConfig: null | {
|
loggingConfig: null | {
|
||||||
cloudWatchConfig: null | unknown;
|
cloudWatchConfig: null | unknown;
|
||||||
@@ -58,6 +78,7 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
|||||||
service: "aws",
|
service: "aws",
|
||||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||||
|
keyCheckBatchSize: KEY_CHECK_BATCH_SIZE,
|
||||||
updateKey,
|
updateKey,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -66,38 +87,52 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
|||||||
const isInitialCheck = !key.lastChecked;
|
const isInitialCheck = !key.lastChecked;
|
||||||
|
|
||||||
if (isInitialCheck) {
|
if (isInitialCheck) {
|
||||||
// Perform checks for all parent model IDs
|
try {
|
||||||
const results = await Promise.all(
|
await this.checkInferenceProfiles(key);
|
||||||
KNOWN_MODEL_IDS.filter(([model]) =>
|
} catch (e) {
|
||||||
// Skip checks for models that are disabled anyway
|
const asError = e as AxiosError<AwsError>;
|
||||||
config.allowedModelFamilies.includes(getAwsBedrockModelFamily(model))
|
const data = asError.response?.data;
|
||||||
).map(async ([model, ...aliases]) => ({
|
|
||||||
models: [model, ...aliases],
|
|
||||||
success: await this.invokeModel(model, key),
|
|
||||||
}))
|
|
||||||
);
|
|
||||||
|
|
||||||
// Filter out models that are disabled
|
|
||||||
const modelIds = results
|
|
||||||
.filter(({ success }) => success)
|
|
||||||
.flatMap(({ models }) => models);
|
|
||||||
|
|
||||||
if (modelIds.length === 0) {
|
|
||||||
this.log.warn(
|
this.log.warn(
|
||||||
{ key: key.hash },
|
{ key: key.hash, error: e.message, data },
|
||||||
"Key does not have access to any models; disabling."
|
"Cannot list inference profiles.\n\
|
||||||
|
Principal may be missing `AmazonBedrockFullAccess`, or has no policy allowing action `bedrock:ListInferenceProfiles` against resource `arn:aws:bedrock:*:*:inference-profile/*`.\n\
|
||||||
|
Requests will be made without inference profiles using on-demand quotas, which may be subject to more restrictive rate limits.\n\
|
||||||
|
See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-prereq.html."
|
||||||
);
|
);
|
||||||
return this.updateKey(key.hash, { isDisabled: true });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
this.updateKey(key.hash, {
|
|
||||||
modelIds,
|
|
||||||
modelFamilies: Array.from(
|
|
||||||
new Set(modelIds.map(getAwsBedrockModelFamily))
|
|
||||||
),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Perform checks for all parent model IDs
|
||||||
|
const results = await Promise.all(
|
||||||
|
KNOWN_MODEL_IDS.filter(([model]) =>
|
||||||
|
// Skip checks for models that are disabled anyway
|
||||||
|
config.allowedModelFamilies.includes(getAwsBedrockModelFamily(model))
|
||||||
|
).map(async ([model, ...aliases]) => ({
|
||||||
|
models: [model, ...aliases],
|
||||||
|
success: await this.invokeModel(model, key),
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
|
||||||
|
// Filter out models that are disabled
|
||||||
|
const modelIds = results
|
||||||
|
.filter(({ success }) => success)
|
||||||
|
.flatMap(({ models }) => models);
|
||||||
|
|
||||||
|
if (modelIds.length === 0) {
|
||||||
|
this.log.warn(
|
||||||
|
{ key: key.hash },
|
||||||
|
"Key does not have access to any models; disabling."
|
||||||
|
);
|
||||||
|
return this.updateKey(key.hash, { isDisabled: true });
|
||||||
|
}
|
||||||
|
|
||||||
|
this.updateKey(key.hash, {
|
||||||
|
modelIds,
|
||||||
|
modelFamilies: Array.from(
|
||||||
|
new Set(modelIds.map(getAwsBedrockModelFamily))
|
||||||
|
),
|
||||||
|
});
|
||||||
|
|
||||||
this.log.info(
|
this.log.info(
|
||||||
{
|
{
|
||||||
key: key.hash,
|
key: key.hash,
|
||||||
@@ -179,6 +214,36 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
|||||||
key: AwsBedrockKey
|
key: AwsBedrockKey
|
||||||
): Promise<boolean> {
|
): Promise<boolean> {
|
||||||
if (model.includes("claude")) {
|
if (model.includes("claude")) {
|
||||||
|
// If inference profiles are available, try testing model with them.
|
||||||
|
// If they are not available or the invocation fails with the inference
|
||||||
|
// profile, fall back to regular model ID.
|
||||||
|
const { region } = AwsKeyChecker.getCredentialsFromKey(key);
|
||||||
|
const continent = region.split("-")[0];
|
||||||
|
const profile = key.inferenceProfileIds.find(
|
||||||
|
(id) => `${continent}.${model}` === id
|
||||||
|
);
|
||||||
|
|
||||||
|
if (profile) {
|
||||||
|
this.log.debug(
|
||||||
|
{ key: key.hash, model, profile },
|
||||||
|
"Testing model via inference profile."
|
||||||
|
);
|
||||||
|
let result: boolean;
|
||||||
|
try {
|
||||||
|
result = await this.testClaudeModel(key, profile);
|
||||||
|
} catch (e) {
|
||||||
|
this.log.error(
|
||||||
|
{ key: key.hash, model, profile, error: e.message },
|
||||||
|
"Error testing model with inference profile; trying model ID directly."
|
||||||
|
);
|
||||||
|
result = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the profile worked, we'll return success. Caller will add the
|
||||||
|
// model (not the profile) to the list of enabled models, but the
|
||||||
|
// profile will be used when the key is used for inference.
|
||||||
|
if (result) return true;
|
||||||
|
}
|
||||||
return this.testClaudeModel(key, model);
|
return this.testClaudeModel(key, model);
|
||||||
} else if (model.includes("mistral")) {
|
} else if (model.includes("mistral")) {
|
||||||
return this.testMistralModel(key, model);
|
return this.testMistralModel(key, model);
|
||||||
@@ -222,6 +287,10 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
|||||||
status === 403 &&
|
status === 403 &&
|
||||||
errorMessage?.match(/access to the model with the specified model ID/)
|
errorMessage?.match(/access to the model with the specified model ID/)
|
||||||
) {
|
) {
|
||||||
|
this.log.debug(
|
||||||
|
{ key: key.hash, model, errorType, data, status, headers },
|
||||||
|
"Model is not available (principal does not have access)."
|
||||||
|
);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,7 +299,7 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
|||||||
if (status === 404) {
|
if (status === 404) {
|
||||||
this.log.debug(
|
this.log.debug(
|
||||||
{ region: creds.region, model, key: key.hash },
|
{ region: creds.region, model, key: key.hash },
|
||||||
"Model not supported in this AWS region."
|
"Model is not available (not supported in this AWS region)."
|
||||||
);
|
);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -242,14 +311,14 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
|||||||
if (!correctErrorType || !correctErrorMessage) {
|
if (!correctErrorType || !correctErrorMessage) {
|
||||||
this.log.debug(
|
this.log.debug(
|
||||||
{ key: key.hash, model, errorType, data, status },
|
{ key: key.hash, model, errorType, data, status },
|
||||||
"AWS InvokeModel test unsuccessful."
|
"Model is not available (request rejected)."
|
||||||
);
|
);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
this.log.debug(
|
this.log.debug(
|
||||||
{ key: key.hash, model, errorType, data, status },
|
{ key: key.hash, model, errorType, data, status },
|
||||||
"AWS InvokeModel test successful."
|
"Model is available."
|
||||||
);
|
);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -283,7 +352,7 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
|||||||
if (status === 403 || status === 404) {
|
if (status === 403 || status === 404) {
|
||||||
this.log.debug(
|
this.log.debug(
|
||||||
{ key: key.hash, model, errorType, data, status },
|
{ key: key.hash, model, errorType, data, status },
|
||||||
"AWS InvokeModel test returned 403 or 404."
|
"Model is not available (no access or unsupported region)."
|
||||||
);
|
);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -293,18 +362,38 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
|||||||
if (isBadRequest && !isValidationError) {
|
if (isBadRequest && !isValidationError) {
|
||||||
this.log.debug(
|
this.log.debug(
|
||||||
{ key: key.hash, model, errorType, data, status, headers },
|
{ key: key.hash, model, errorType, data, status, headers },
|
||||||
"AWS InvokeModel test returned 400 but not a validation error."
|
"Model is not available (request rejected)."
|
||||||
);
|
);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
this.log.debug(
|
this.log.debug(
|
||||||
{ key: key.hash, model, errorType, data, status },
|
{ key: key.hash, model, errorType, data, status },
|
||||||
"AWS InvokeModel test successful."
|
"Model is available."
|
||||||
);
|
);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async checkInferenceProfiles(key: AwsBedrockKey) {
|
||||||
|
const creds = AwsKeyChecker.getCredentialsFromKey(key);
|
||||||
|
const req: AxiosRequestConfig = {
|
||||||
|
method: "GET",
|
||||||
|
url: GET_LIST_INFERENCE_PROFILES_URL(creds.region),
|
||||||
|
headers: { accept: "application/json" },
|
||||||
|
};
|
||||||
|
await AwsKeyChecker.signRequestForAws(req, key);
|
||||||
|
const { data } = await axios.request<GetInferenceProfilesResponse>(req);
|
||||||
|
const { inferenceProfileSummaries } = data;
|
||||||
|
const profileIds = inferenceProfileSummaries.map(
|
||||||
|
(p) => p.inferenceProfileId
|
||||||
|
);
|
||||||
|
this.log.debug(
|
||||||
|
{ key: key.hash, profileIds, region: creds.region },
|
||||||
|
"Inference profiles found."
|
||||||
|
);
|
||||||
|
this.updateKey(key.hash, { inferenceProfileIds: profileIds });
|
||||||
|
}
|
||||||
|
|
||||||
private async checkLoggingConfiguration(key: AwsBedrockKey) {
|
private async checkLoggingConfiguration(key: AwsBedrockKey) {
|
||||||
if (config.allowAwsLogging) {
|
if (config.allowAwsLogging) {
|
||||||
// Don't check logging status if we're allowing it to reduce API calls.
|
// Don't check logging status if we're allowing it to reduce API calls.
|
||||||
@@ -373,7 +462,8 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
|||||||
method,
|
method,
|
||||||
protocol: "https:",
|
protocol: "https:",
|
||||||
hostname: url.hostname,
|
hostname: url.hostname,
|
||||||
path: url.pathname + url.search,
|
path: url.pathname,
|
||||||
|
query: Object.fromEntries(url.searchParams),
|
||||||
headers: { Host: url.hostname, ...plainHeaders },
|
headers: { Host: url.hostname, ...plainHeaders },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
|
|||||||
*/
|
*/
|
||||||
awsLoggingStatus: "unknown" | "disabled" | "enabled";
|
awsLoggingStatus: "unknown" | "disabled" | "enabled";
|
||||||
modelIds: string[];
|
modelIds: string[];
|
||||||
|
inferenceProfileIds: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -72,6 +73,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
|
|||||||
.slice(0, 8)}`,
|
.slice(0, 8)}`,
|
||||||
lastChecked: 0,
|
lastChecked: 0,
|
||||||
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
|
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
|
||||||
|
inferenceProfileIds: [],
|
||||||
["aws-claudeTokens"]: 0,
|
["aws-claudeTokens"]: 0,
|
||||||
["aws-claude-opusTokens"]: 0,
|
["aws-claude-opusTokens"]: 0,
|
||||||
["aws-mistral-tinyTokens"]: 0,
|
["aws-mistral-tinyTokens"]: 0,
|
||||||
@@ -135,7 +137,21 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const selectedKey = prioritizeKeys(availableKeys)[0];
|
/**
|
||||||
|
* Comparator for prioritizing keys on inference profile compatibility.
|
||||||
|
* Requests made via inference profiles have higher rate limits so we want
|
||||||
|
* to use keys with compatible inference profiles first.
|
||||||
|
*/
|
||||||
|
const hasInferenceProfile = (
|
||||||
|
a: AwsBedrockKey,
|
||||||
|
b: AwsBedrockKey
|
||||||
|
) => {
|
||||||
|
const aMatch = +a.inferenceProfileIds.some((p) => p.includes(model));
|
||||||
|
const bMatch = +b.inferenceProfileIds.some((p) => p.includes(model));
|
||||||
|
return aMatch - bMatch;
|
||||||
|
};
|
||||||
|
|
||||||
|
const selectedKey = prioritizeKeys(availableKeys, hasInferenceProfile)[0];
|
||||||
selectedKey.lastUsed = Date.now();
|
selectedKey.lastUsed = Date.now();
|
||||||
this.throttle(selectedKey.hash);
|
this.throttle(selectedKey.hash);
|
||||||
return { ...selectedKey };
|
return { ...selectedKey };
|
||||||
|
|||||||
@@ -6,10 +6,12 @@ import { GcpModelFamily } from "../../models";
|
|||||||
|
|
||||||
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||||
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
|
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
|
||||||
const GCP_HOST =
|
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
|
||||||
process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
|
|
||||||
const POST_STREAM_RAW_URL = (project: string, region: string, model: string) =>
|
const POST_STREAM_RAW_URL = (project: string, region: string, model: string) =>
|
||||||
`https://${GCP_HOST.replace("%REGION%", region)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
|
`https://${GCP_HOST.replace(
|
||||||
|
"%REGION%",
|
||||||
|
region
|
||||||
|
)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
|
||||||
const TEST_MESSAGES = [
|
const TEST_MESSAGES = [
|
||||||
{ role: "user", content: "Hi!" },
|
{ role: "user", content: "Hi!" },
|
||||||
{ role: "assistant", content: "Hello!" },
|
{ role: "assistant", content: "Hello!" },
|
||||||
@@ -23,6 +25,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
service: "gcp",
|
service: "gcp",
|
||||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||||
|
recurringChecksEnabled: false,
|
||||||
updateKey,
|
updateKey,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -38,9 +41,8 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
this.invokeModel("claude-3-5-sonnet@20240620", key, true),
|
this.invokeModel("claude-3-5-sonnet@20240620", key, true),
|
||||||
];
|
];
|
||||||
|
|
||||||
const [sonnet, haiku, opus, sonnet35] =
|
const [sonnet, haiku, opus, sonnet35] = await Promise.all(checks);
|
||||||
await Promise.all(checks);
|
|
||||||
|
|
||||||
this.log.debug(
|
this.log.debug(
|
||||||
{ key: key.hash, sonnet, haiku, opus, sonnet35 },
|
{ key: key.hash, sonnet, haiku, opus, sonnet35 },
|
||||||
"GCP model initial tests complete."
|
"GCP model initial tests complete."
|
||||||
@@ -66,20 +68,17 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
if (key.haikuEnabled) {
|
if (key.haikuEnabled) {
|
||||||
await this.invokeModel("claude-3-haiku@20240307", key, false)
|
await this.invokeModel("claude-3-haiku@20240307", key, false);
|
||||||
} else if (key.sonnetEnabled) {
|
} else if (key.sonnetEnabled) {
|
||||||
await this.invokeModel("claude-3-sonnet@20240229", key, false)
|
await this.invokeModel("claude-3-sonnet@20240229", key, false);
|
||||||
} else if (key.sonnet35Enabled) {
|
} else if (key.sonnet35Enabled) {
|
||||||
await this.invokeModel("claude-3-5-sonnet@20240620", key, false)
|
await this.invokeModel("claude-3-5-sonnet@20240620", key, false);
|
||||||
} else {
|
} else {
|
||||||
await this.invokeModel("claude-3-opus@20240229", key, false)
|
await this.invokeModel("claude-3-opus@20240229", key, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
this.updateKey(key.hash, { lastChecked: Date.now() });
|
this.updateKey(key.hash, { lastChecked: Date.now() });
|
||||||
this.log.debug(
|
this.log.debug({ key: key.hash }, "GCP key check complete.");
|
||||||
{ key: key.hash},
|
|
||||||
"GCP key check complete."
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
this.log.info(
|
this.log.info(
|
||||||
@@ -134,8 +133,12 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
*/
|
*/
|
||||||
private async invokeModel(model: string, key: GcpKey, initial: boolean) {
|
private async invokeModel(model: string, key: GcpKey, initial: boolean) {
|
||||||
const creds = GcpKeyChecker.getCredentialsFromKey(key);
|
const creds = GcpKeyChecker.getCredentialsFromKey(key);
|
||||||
const signedJWT = await GcpKeyChecker.createSignedJWT(creds.clientEmail, creds.privateKey)
|
const signedJWT = await GcpKeyChecker.createSignedJWT(
|
||||||
const [accessToken, jwtError] = await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT)
|
creds.clientEmail,
|
||||||
|
creds.privateKey
|
||||||
|
);
|
||||||
|
const [accessToken, jwtError] =
|
||||||
|
await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT);
|
||||||
if (accessToken === null) {
|
if (accessToken === null) {
|
||||||
this.log.warn(
|
this.log.warn(
|
||||||
{ key: key.hash, jwtError },
|
{ key: key.hash, jwtError },
|
||||||
@@ -151,15 +154,19 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
const { data, status } = await axios.post(
|
const { data, status } = await axios.post(
|
||||||
POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
|
POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
|
||||||
payload,
|
payload,
|
||||||
{
|
{
|
||||||
headers: GcpKeyChecker.getRequestHeaders(accessToken),
|
headers: GcpKeyChecker.getRequestHeaders(accessToken),
|
||||||
validateStatus: initial ? () => true : (status: number) => status >= 200 && status < 300
|
validateStatus: initial
|
||||||
|
? () => true
|
||||||
|
: (status: number) => status >= 200 && status < 300,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
this.log.debug({ key: key.hash, data }, "Response from GCP");
|
this.log.debug({ key: key.hash, data }, "Response from GCP");
|
||||||
|
|
||||||
if (initial) {
|
if (initial) {
|
||||||
return (status >= 200 && status < 300) || (status === 429 || status === 529);
|
return (
|
||||||
|
(status >= 200 && status < 300) || status === 429 || status === 529
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@@ -178,10 +185,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
let cryptoKey = await crypto.subtle.importKey(
|
let cryptoKey = await crypto.subtle.importKey(
|
||||||
"pkcs8",
|
"pkcs8",
|
||||||
GcpKeyChecker.str2ab(atob(pkey)),
|
GcpKeyChecker.str2ab(atob(pkey)),
|
||||||
{
|
{ name: "RSASSA-PKCS1-v1_5", hash: { name: "SHA-256" } },
|
||||||
name: "RSASSA-PKCS1-v1_5",
|
|
||||||
hash: { name: "SHA-256" },
|
|
||||||
},
|
|
||||||
false,
|
false,
|
||||||
["sign"]
|
["sign"]
|
||||||
);
|
);
|
||||||
@@ -190,10 +194,7 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
const issued = Math.floor(Date.now() / 1000);
|
const issued = Math.floor(Date.now() / 1000);
|
||||||
const expires = issued + 600;
|
const expires = issued + 600;
|
||||||
|
|
||||||
const header = {
|
const header = { alg: "RS256", typ: "JWT" };
|
||||||
alg: "RS256",
|
|
||||||
typ: "JWT",
|
|
||||||
};
|
|
||||||
|
|
||||||
const payload = {
|
const payload = {
|
||||||
iss: email,
|
iss: email,
|
||||||
@@ -203,8 +204,12 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
scope: "https://www.googleapis.com/auth/cloud-platform",
|
scope: "https://www.googleapis.com/auth/cloud-platform",
|
||||||
};
|
};
|
||||||
|
|
||||||
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(header));
|
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(
|
||||||
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(payload));
|
JSON.stringify(header)
|
||||||
|
);
|
||||||
|
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(
|
||||||
|
JSON.stringify(payload)
|
||||||
|
);
|
||||||
|
|
||||||
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
|
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
|
||||||
|
|
||||||
@@ -218,7 +223,9 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
return `${unsignedToken}.${encodedSignature}`;
|
return `${unsignedToken}.${encodedSignature}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
static async exchangeJwtForAccessToken(signed_jwt: string): Promise<[string | null, string]> {
|
static async exchangeJwtForAccessToken(
|
||||||
|
signed_jwt: string
|
||||||
|
): Promise<[string | null, string]> {
|
||||||
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
|
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
|
||||||
const params = {
|
const params = {
|
||||||
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||||
@@ -252,7 +259,11 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
static urlSafeBase64Encode(data: string | ArrayBuffer): string {
|
static urlSafeBase64Encode(data: string | ArrayBuffer): string {
|
||||||
let base64: string;
|
let base64: string;
|
||||||
if (typeof data === "string") {
|
if (typeof data === "string") {
|
||||||
base64 = btoa(encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => String.fromCharCode(parseInt("0x" + p1, 16))));
|
base64 = btoa(
|
||||||
|
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
|
||||||
|
String.fromCharCode(parseInt("0x" + p1, 16))
|
||||||
|
)
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
|
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
|
||||||
}
|
}
|
||||||
@@ -260,7 +271,10 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static getRequestHeaders(accessToken: string) {
|
static getRequestHeaders(accessToken: string) {
|
||||||
return { "Authorization": `Bearer ${accessToken}`, "Content-Type": "application/json" };
|
return {
|
||||||
|
Authorization: `Bearer ${accessToken}`,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
static getCredentialsFromKey(key: GcpKey) {
|
static getCredentialsFromKey(key: GcpKey) {
|
||||||
@@ -269,9 +283,12 @@ export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
|||||||
throw new Error("Invalid GCP key");
|
throw new Error("Invalid GCP key");
|
||||||
}
|
}
|
||||||
const privateKey = rawPrivateKey
|
const privateKey = rawPrivateKey
|
||||||
.replace(/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, '')
|
.replace(
|
||||||
|
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
|
||||||
|
""
|
||||||
|
)
|
||||||
.trim();
|
.trim();
|
||||||
|
|
||||||
return { projectId, clientEmail, region, privateKey };
|
return { projectId, clientEmail, region, privateKey };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ type KeyCheckerOptions<TKey extends Key = Key> = {
|
|||||||
service: string;
|
service: string;
|
||||||
keyCheckPeriod: number;
|
keyCheckPeriod: number;
|
||||||
minCheckInterval: number;
|
minCheckInterval: number;
|
||||||
|
keyCheckBatchSize?: number;
|
||||||
recurringChecksEnabled?: boolean;
|
recurringChecksEnabled?: boolean;
|
||||||
updateKey: (hash: string, props: Partial<TKey>) => void;
|
updateKey: (hash: string, props: Partial<TKey>) => void;
|
||||||
};
|
};
|
||||||
@@ -22,6 +23,8 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
|||||||
* than this.
|
* than this.
|
||||||
*/
|
*/
|
||||||
protected readonly keyCheckPeriod: number;
|
protected readonly keyCheckPeriod: number;
|
||||||
|
/** Maximum number of keys to check simultaneously. */
|
||||||
|
protected readonly keyCheckBatchSize: number;
|
||||||
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
|
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
|
||||||
protected readonly keys: TKey[] = [];
|
protected readonly keys: TKey[] = [];
|
||||||
protected log: pino.Logger;
|
protected log: pino.Logger;
|
||||||
@@ -33,6 +36,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
|||||||
this.keyCheckPeriod = opts.keyCheckPeriod;
|
this.keyCheckPeriod = opts.keyCheckPeriod;
|
||||||
this.minCheckInterval = opts.minCheckInterval;
|
this.minCheckInterval = opts.minCheckInterval;
|
||||||
this.recurringChecksEnabled = opts.recurringChecksEnabled ?? true;
|
this.recurringChecksEnabled = opts.recurringChecksEnabled ?? true;
|
||||||
|
this.keyCheckBatchSize = opts.keyCheckBatchSize ?? 12;
|
||||||
this.updateKey = opts.updateKey;
|
this.updateKey = opts.updateKey;
|
||||||
this.service = opts.service;
|
this.service = opts.service;
|
||||||
this.log = logger.child({ module: "key-checker", service: opts.service });
|
this.log = logger.child({ module: "key-checker", service: opts.service });
|
||||||
@@ -78,7 +82,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
|||||||
checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
|
checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
|
||||||
|
|
||||||
if (numUnchecked > 0) {
|
if (numUnchecked > 0) {
|
||||||
const keycheckBatch = uncheckedKeys.slice(0, 12);
|
const keycheckBatch = uncheckedKeys.slice(0, this.keyCheckBatchSize);
|
||||||
|
|
||||||
this.timeout = setTimeout(async () => {
|
this.timeout = setTimeout(async () => {
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -1,12 +1,22 @@
|
|||||||
import { Key } from "./index";
|
import { Key } from "./index";
|
||||||
|
|
||||||
export function prioritizeKeys<T extends Key>(keys: T[]) {
|
/**
|
||||||
// Sorts keys from highest priority to lowest priority, where priority is:
|
* Given a list of keys, returns a new list of keys sorted from highest to
|
||||||
// 1. Keys which are not rate limited
|
* lowest priority. Keys are prioritized in the following order:
|
||||||
// a. If all keys were rate limited recently, select the least-recently
|
*
|
||||||
// rate limited key.
|
* 1. Keys which are not rate limited
|
||||||
// 2. Keys which have not been used in the longest time
|
* a. If all keys were rate limited recently, select the least-recently
|
||||||
|
* rate limited key.
|
||||||
|
* b. Otherwise, select the first key.
|
||||||
|
* 2. Keys which have not been used in the longest time
|
||||||
|
* 3. Keys according to the custom comparator, if provided
|
||||||
|
* @param keys The list of keys to sort
|
||||||
|
* @param customComparator A custom comparator function to use for sorting
|
||||||
|
*/
|
||||||
|
export function prioritizeKeys<T extends Key>(
|
||||||
|
keys: T[],
|
||||||
|
customComparator?: (a: T, b: T) => number
|
||||||
|
) {
|
||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
|
|
||||||
return keys.sort((a, b) => {
|
return keys.sort((a, b) => {
|
||||||
@@ -19,6 +29,11 @@ export function prioritizeKeys<T extends Key>(keys: T[]) {
|
|||||||
return a.rateLimitedAt - b.rateLimitedAt;
|
return a.rateLimitedAt - b.rateLimitedAt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (customComparator) {
|
||||||
|
const result = customComparator(a, b);
|
||||||
|
if (result !== 0) return result;
|
||||||
|
}
|
||||||
|
|
||||||
return a.lastUsed - b.lastUsed;
|
return a.lastUsed - b.lastUsed;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -67,6 +67,9 @@ async function getTokenCountForMessages({
|
|||||||
case "image":
|
case "image":
|
||||||
numTokens += await getImageTokenCount(part.source.data);
|
numTokens += await getImageTokenCount(part.source.data);
|
||||||
break;
|
break;
|
||||||
|
case "tool_use":
|
||||||
|
case "tool_result":
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw new Error(`Unsupported Anthropic content type.`);
|
throw new Error(`Unsupported Anthropic content type.`);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
import cookieParser from "cookie-parser";
|
import cookieParser from "cookie-parser";
|
||||||
import expressSession from "express-session";
|
import expressSession from "express-session";
|
||||||
import MemoryStore from "memorystore";
|
import MemoryStore from "memorystore";
|
||||||
import { config, COOKIE_SECRET } from "../config";
|
import { config, SECRET_SIGNING_KEY } from "../config";
|
||||||
|
|
||||||
const ONE_WEEK = 1000 * 60 * 60 * 24 * 7;
|
const ONE_WEEK = 1000 * 60 * 60 * 24 * 7;
|
||||||
|
|
||||||
const cookieParserMiddleware = cookieParser(COOKIE_SECRET);
|
const cookieParserMiddleware = cookieParser(SECRET_SIGNING_KEY);
|
||||||
|
|
||||||
const sessionMiddleware = expressSession({
|
const sessionMiddleware = expressSession({
|
||||||
secret: COOKIE_SECRET,
|
secret: SECRET_SIGNING_KEY,
|
||||||
resave: false,
|
resave: false,
|
||||||
saveUninitialized: false,
|
saveUninitialized: false,
|
||||||
store: new (MemoryStore(expressSession))({ checkPeriod: ONE_WEEK }),
|
store: new (MemoryStore(expressSession))({ checkPeriod: ONE_WEEK }),
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import crypto from "crypto";
|
|||||||
import express from "express";
|
import express from "express";
|
||||||
import argon2 from "@node-rs/argon2";
|
import argon2 from "@node-rs/argon2";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
|
import { signMessage } from "../../shared/hmac-signing";
|
||||||
import {
|
import {
|
||||||
authenticate,
|
authenticate,
|
||||||
createUser,
|
createUser,
|
||||||
@@ -13,15 +14,13 @@ import { config } from "../../config";
|
|||||||
/** Lockout time after verification in milliseconds */
|
/** Lockout time after verification in milliseconds */
|
||||||
const LOCKOUT_TIME = 1000 * 60; // 60 seconds
|
const LOCKOUT_TIME = 1000 * 60; // 60 seconds
|
||||||
|
|
||||||
/** HMAC key for signing challenges; regenerated on startup */
|
let powKeySalt = crypto.randomBytes(32).toString("hex");
|
||||||
let hmacSecret = crypto.randomBytes(32).toString("hex");
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Regenerate the HMAC key used for signing challenges. Calling this function
|
* Invalidates any outstanding unsolved challenges.
|
||||||
* will invalidate all existing challenges.
|
|
||||||
*/
|
*/
|
||||||
export function invalidatePowHmacKey() {
|
export function invalidatePowChallenges() {
|
||||||
hmacSecret = crypto.randomBytes(32).toString("hex");
|
powKeySalt = crypto.randomBytes(32).toString("hex");
|
||||||
}
|
}
|
||||||
|
|
||||||
const argon2Params = {
|
const argon2Params = {
|
||||||
@@ -141,16 +140,6 @@ function generateChallenge(clientIp?: string, token?: string): Challenge {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
function signMessage(msg: any): string {
|
|
||||||
const hmac = crypto.createHmac("sha256", hmacSecret);
|
|
||||||
if (typeof msg === "object") {
|
|
||||||
hmac.update(JSON.stringify(msg));
|
|
||||||
} else {
|
|
||||||
hmac.update(msg);
|
|
||||||
}
|
|
||||||
return hmac.digest("hex");
|
|
||||||
}
|
|
||||||
|
|
||||||
async function verifySolution(
|
async function verifySolution(
|
||||||
challenge: Challenge,
|
challenge: Challenge,
|
||||||
solution: string,
|
solution: string,
|
||||||
@@ -225,11 +214,11 @@ router.post("/challenge", (req, res) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const challenge = generateChallenge(req.ip, refreshToken);
|
const challenge = generateChallenge(req.ip, refreshToken);
|
||||||
const signature = signMessage(challenge);
|
const signature = signMessage(challenge, powKeySalt);
|
||||||
res.json({ challenge, signature });
|
res.json({ challenge, signature });
|
||||||
} else {
|
} else {
|
||||||
const challenge = generateChallenge(req.ip);
|
const challenge = generateChallenge(req.ip);
|
||||||
const signature = signMessage(challenge);
|
const signature = signMessage(challenge, powKeySalt);
|
||||||
res.json({ challenge, signature });
|
res.json({ challenge, signature });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -253,7 +242,7 @@ router.post("/verify", async (req, res) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const { challenge, signature, solution } = result.data;
|
const { challenge, signature, solution } = result.data;
|
||||||
if (signMessage(challenge) !== signature) {
|
if (signMessage(challenge, powKeySalt) !== signature) {
|
||||||
res.status(400).json({
|
res.status(400).json({
|
||||||
error:
|
error:
|
||||||
"Invalid signature; server may have restarted since challenge was issued. Please request a new challenge.",
|
"Invalid signature; server may have restarted since challenge was issued. Please request a new challenge.",
|
||||||
|
|||||||
Reference in New Issue
Block a user