20 Commits

Author SHA1 Message Date
nai-degen 49aabddd71 test 2024-01-07 19:11:33 -06:00
nai-degen 7b0892ddae fixes unawaited call to async enqueue 2024-01-07 16:23:53 -06:00
nai-degen 7f92565739 SSE queueing adjustments, untested 2024-01-07 16:19:22 -06:00
nai-degen 936d3c0721 corrects nodejs max heap memory config 2024-01-07 16:16:27 -06:00
nai-degen 4ffa7fb12b reduces max request body size for now 2024-01-07 13:03:24 -06:00
nai-degen 8dc7464381 strips extraneous properties on zod schemas 2024-01-07 13:00:48 -06:00
nai-degen d2cd24bfd2 suggest larger nodejs max heap 2024-01-07 12:58:50 -06:00
twinkletoes e33f778192 Change mistral-medium friendly name (khanon/oai-reverse-proxy!59) 2023-12-26 00:27:17 +00:00
twinkletoes 4a823b216f Mistral AI support (khanon/oai-reverse-proxy!58) 2023-12-25 18:33:16 +00:00
nai-degen 01e76cbb1c restores accidentally deleted line breaking infopage stats 2023-12-17 00:25:58 -06:00
nai-degen 655703e680 refactors infopage 2023-12-16 20:30:20 -06:00
nai-degen 3be2687793 tries to detect Azure GPT4-Turbo deployments more reliably 2023-12-15 12:14:23 -06:00
nai-degen 5599a83ae4 improves streaming error handling 2023-12-14 05:01:10 -06:00
nai-degen de34d41918 fixes gemini name prefixing when 'Add character names' is disabled in ST 2023-12-13 23:21:30 -06:00
nai-degen c5cd90dcef adjusts prompt transform to discourage Gemini from speaking for user 2023-12-13 23:03:57 -06:00
nai-degen 8a135a960d fixes gemini prompt reformatting for jbs; adds stop sequences 2023-12-13 21:45:53 -06:00
nai-degen 707cbbce16 fixes gemini throwing an error on JB prompts 2023-12-13 19:14:31 -06:00
khanon fad16cc268 Add Google AI API (khanon/oai-reverse-proxy!57) 2023-12-13 21:56:07 +00:00
nai-degen 0d3682197c treats 403 from anthropic as key dead 2023-12-11 09:13:53 -06:00
valadaptive e0624e30fd Fix some corner cases in SSE parsing (khanon/oai-reverse-proxy!56) 2023-12-09 06:18:01 +00:00
60 changed files with 2654 additions and 1130 deletions
+3 -3
@@ -34,10 +34,10 @@
 # Which model types users are allowed to access.
 # The following model families are recognized:
-# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
+# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | mistral-tiny | mistral-small | mistral-medium | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
 # By default, all models are allowed except for 'dall-e'. To allow DALL-E image
 # generation, uncomment the line below and add 'dall-e' to the list.
-# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
+# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,mistral-tiny,mistral-small,mistral-medium,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
 # URLs from which requests will be blocked.
 # BLOCKED_ORIGINS=reddit.com,9gag.com
@@ -95,7 +95,7 @@
 # TOKEN_QUOTA_GPT4_TURBO=0
 # TOKEN_QUOTA_DALL_E=0
 # TOKEN_QUOTA_CLAUDE=0
-# TOKEN_QUOTA_BISON=0
+# TOKEN_QUOTA_GEMINI_PRO=0
 # TOKEN_QUOTA_AWS_CLAUDE=0
 # How often to refresh token quotas. (hourly | daily)
+2
@@ -10,4 +10,6 @@ COPY Dockerfile greeting.md* .env* ./
 RUN npm run build
 EXPOSE 7860
 ENV NODE_ENV=production
+# Huggingface free VMs have 16GB of RAM so we can be greedy
+ENV NODE_OPTIONS="--max-old-space-size=12882"
 CMD [ "npm", "start" ]
+1 -1
@@ -35,7 +35,7 @@ Add `dall-e` to the `ALLOWED_MODEL_FAMILIES` environment variable to enable DALL
 ALLOWED_MODEL_FAMILIES=turbo,gpt-4,gpt-4turbo,dall-e
 # All models as of this writing
-ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,dall-e
+ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,dall-e
 ```
 Refer to [.env.example](../.env.example) for a full list of supported model families. You can add `dall-e` to that list to enable all models.
+1
@@ -32,6 +32,7 @@ COPY Dockerfile greeting.md* .env* ./
 RUN npm run build
 EXPOSE 7860
 ENV NODE_ENV=production
+ENV NODE_OPTIONS="--max-old-space-size=12882"
 CMD [ "npm", "start" ]
 ```
 - Click "Commit new file to `main`" to save the Dockerfile.
+34
@@ -36,6 +36,7 @@
"sanitize-html": "^2.11.0", "sanitize-html": "^2.11.0",
"sharp": "^0.32.6", "sharp": "^0.32.6",
"showdown": "^2.1.0", "showdown": "^2.1.0",
"stream-json": "^1.8.0",
"tiktoken": "^1.0.10", "tiktoken": "^1.0.10",
"uuid": "^9.0.0", "uuid": "^9.0.0",
"zlib": "^1.0.5", "zlib": "^1.0.5",
@@ -51,6 +52,7 @@
"@types/node-schedule": "^2.1.0", "@types/node-schedule": "^2.1.0",
"@types/sanitize-html": "^2.9.0", "@types/sanitize-html": "^2.9.0",
"@types/showdown": "^2.0.0", "@types/showdown": "^2.0.0",
"@types/stream-json": "^1.7.7",
"@types/uuid": "^9.0.1", "@types/uuid": "^9.0.1",
"concurrently": "^8.0.1", "concurrently": "^8.0.1",
"esbuild": "^0.17.16", "esbuild": "^0.17.16",
@@ -1185,6 +1187,25 @@
"integrity": "sha512-70xBJoLv+oXjB5PhtA8vo7erjLDp9/qqI63SRHm4REKrwuPOLs8HhXwlZJBJaB4kC18cCZ1UUZ6Fb/PLFW4TCA==", "integrity": "sha512-70xBJoLv+oXjB5PhtA8vo7erjLDp9/qqI63SRHm4REKrwuPOLs8HhXwlZJBJaB4kC18cCZ1UUZ6Fb/PLFW4TCA==",
"dev": true "dev": true
}, },
"node_modules/@types/stream-chain": {
"version": "2.0.4",
"resolved": "https://registry.npmjs.org/@types/stream-chain/-/stream-chain-2.0.4.tgz",
"integrity": "sha512-V7TsWLHrx79KumkHqSD7F8eR6POpEuWb6PuXJ7s/dRHAf3uVst3Jkp1yZ5XqIfECZLQ4a28vBVstTErmsMBvaQ==",
"dev": true,
"dependencies": {
"@types/node": "*"
}
},
"node_modules/@types/stream-json": {
"version": "1.7.7",
"resolved": "https://registry.npmjs.org/@types/stream-json/-/stream-json-1.7.7.tgz",
"integrity": "sha512-hHG7cLQ09H/m9i0jzL6UJAeLLxIWej90ECn0svO4T8J0nGcl89xZDQ2ujT4WKlvg0GWkcxJbjIDzW/v7BYUM6Q==",
"dev": true,
"dependencies": {
"@types/node": "*",
"@types/stream-chain": "*"
}
},
"node_modules/@types/uuid": { "node_modules/@types/uuid": {
"version": "9.0.1", "version": "9.0.1",
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.1.tgz", "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.1.tgz",
@@ -5135,6 +5156,11 @@
"node": ">= 0.8" "node": ">= 0.8"
} }
}, },
"node_modules/stream-chain": {
"version": "2.2.5",
"resolved": "https://registry.npmjs.org/stream-chain/-/stream-chain-2.2.5.tgz",
"integrity": "sha512-1TJmBx6aSWqZ4tx7aTpBDXK0/e2hhcNSTV8+CbFJtDjbb+I1mZ8lHit0Grw9GRT+6JbIrrDd8esncgBi8aBXGA=="
},
"node_modules/stream-events": { "node_modules/stream-events": {
"version": "1.0.5", "version": "1.0.5",
"resolved": "https://registry.npmjs.org/stream-events/-/stream-events-1.0.5.tgz", "resolved": "https://registry.npmjs.org/stream-events/-/stream-events-1.0.5.tgz",
@@ -5144,6 +5170,14 @@
"stubs": "^3.0.0" "stubs": "^3.0.0"
} }
}, },
"node_modules/stream-json": {
"version": "1.8.0",
"resolved": "https://registry.npmjs.org/stream-json/-/stream-json-1.8.0.tgz",
"integrity": "sha512-HZfXngYHUAr1exT4fxlbc1IOce1RYxp2ldeaf97LYCOPSoOqY/1Psp7iGvpb+6JIOgkra9zDYnPX01hGAHzEPw==",
"dependencies": {
"stream-chain": "^2.2.5"
}
},
"node_modules/stream-shift": { "node_modules/stream-shift": {
"version": "1.0.1", "version": "1.0.1",
"resolved": "https://registry.npmjs.org/stream-shift/-/stream-shift-1.0.1.tgz", "resolved": "https://registry.npmjs.org/stream-shift/-/stream-shift-1.0.1.tgz",
+2
@@ -44,6 +44,7 @@
"sanitize-html": "^2.11.0", "sanitize-html": "^2.11.0",
"sharp": "^0.32.6", "sharp": "^0.32.6",
"showdown": "^2.1.0", "showdown": "^2.1.0",
"stream-json": "^1.8.0",
"tiktoken": "^1.0.10", "tiktoken": "^1.0.10",
"uuid": "^9.0.0", "uuid": "^9.0.0",
"zlib": "^1.0.5", "zlib": "^1.0.5",
@@ -59,6 +60,7 @@
"@types/node-schedule": "^2.1.0", "@types/node-schedule": "^2.1.0",
"@types/sanitize-html": "^2.9.0", "@types/sanitize-html": "^2.9.0",
"@types/showdown": "^2.0.0", "@types/showdown": "^2.0.0",
"@types/stream-json": "^1.7.7",
"@types/uuid": "^9.0.1", "@types/uuid": "^9.0.1",
"concurrently": "^8.0.1", "concurrently": "^8.0.1",
"esbuild": "^0.17.16", "esbuild": "^0.17.16",
+31 -3
@@ -81,7 +81,7 @@ Authorization: Bearer {{proxy-key}}
 Content-Type: application/json
 {
-  "model": "gpt-3.5-turbo",
+  "model": "gpt-4-1106-preview",
   "max_tokens": 20,
   "stream": true,
   "temperature": 1,
@@ -231,8 +231,36 @@ Content-Type: application/json
 }
 ###
-# @name Proxy / Google PaLM -- OpenAI-to-PaLM API Translation
-POST {{proxy-host}}/proxy/google-palm/v1/chat/completions
+# @name Proxy / Azure OpenAI -- Native Chat Completions
+POST {{proxy-host}}/proxy/azure/openai/chat/completions
+Authorization: Bearer {{proxy-key}}
+Content-Type: application/json
+{
+  "model": "gpt-4",
+  "max_tokens": 20,
+  "stream": true,
+  "temperature": 1,
+  "seed": 2,
+  "messages": [
+    {
+      "role": "user",
+      "content": "Hi what is the name of the fourth president of the united states?"
+    },
+    {
+      "role": "assistant",
+      "content": "That would be George Washington."
+    },
+    {
+      "role": "user",
+      "content": "That's not right."
+    }
+  ]
+}
+###
+# @name Proxy / Google AI -- OpenAI-to-Google AI API Translation
+POST {{proxy-host}}/proxy/google-ai/v1/chat/completions
 Authorization: Bearer {{proxy-key}}
 Content-Type: application/json
+4 -3
@@ -1,6 +1,6 @@
 const axios = require("axios");
-const concurrentRequests = 5;
+const concurrentRequests = 75;
 const headers = {
   Authorization: "Bearer test",
   "Content-Type": "application/json",
@@ -16,7 +16,7 @@ const payload = {
 const makeRequest = async (i) => {
   try {
     const response = await axios.post(
-      "http://localhost:7860/proxy/azure/openai/v1/chat/completions",
+      "http://localhost:7860/proxy/google-ai/v1/chat/completions",
       payload,
       { headers }
     );
@@ -25,7 +25,8 @@ const makeRequest = async (i) => {
       response.data
     );
   } catch (error) {
-    console.error(`Error in req ${i}:`, error.message);
+    const msg = error.response
+    console.error(`Error in req ${i}:`, error.message, msg || "");
   }
 };
+3 -2
@@ -4,7 +4,8 @@ import { HttpError } from "../shared/errors";
 import { injectLocals } from "../shared/inject-locals";
 import { withSession } from "../shared/with-session";
 import { injectCsrfToken, checkCsrfToken } from "../shared/inject-csrf";
-import { buildInfoPageHtml } from "../info-page";
+import { renderPage } from "../info-page";
+import { buildInfo } from "../service-info";
 import { loginRouter } from "./login";
 import { usersApiRouter as apiRouter } from "./api/users";
 import { usersWebRouter as webRouter } from "./web/manage";
@@ -26,7 +27,7 @@ adminRouter.use("/", loginRouter);
 adminRouter.use("/manage", authorize({ via: "cookie" }), webRouter);
 adminRouter.use("/service-info", authorize({ via: "cookie" }), (req, res) => {
   return res.send(
-    buildInfoPageHtml(req.protocol + "://" + req.get("host"), true)
+    renderPage(buildInfo(req.protocol + "://" + req.get("host"), true))
   );
 });
+1 -1
@@ -200,7 +200,7 @@ router.post("/maintenance", (req, res) => {
keyPool.recheck("anthropic"); keyPool.recheck("anthropic");
const size = keyPool const size = keyPool
.list() .list()
.filter((k) => k.service !== "google-palm").length; .filter((k) => k.service !== "google-ai").length;
flash.type = "success"; flash.type = "success";
flash.message = `Scheduled recheck of ${size} keys for OpenAI and Anthropic.`; flash.message = `Scheduled recheck of ${size} keys for OpenAI and Anthropic.`;
break; break;
+41 -15
@@ -4,6 +4,7 @@ import path from "path";
 import pino from "pino";
 import type { ModelFamily } from "./shared/models";
 import { MODEL_FAMILIES } from "./shared/models";
 dotenv.config();
 const startupLogger = pino({ level: "debug" }).child({ module: "startup" });
@@ -19,8 +20,16 @@ type Config = {
   openaiKey?: string;
   /** Comma-delimited list of Anthropic API keys. */
   anthropicKey?: string;
-  /** Comma-delimited list of Google PaLM API keys. */
-  googlePalmKey?: string;
+  /**
+   * Comma-delimited list of Google AI API keys. Note that these are not the
+   * same as the GCP keys/credentials used for Vertex AI; the models are the
+   * same but the APIs are different. Vertex is the GCP product for enterprise.
+   **/
+  googleAIKey?: string;
+  /**
+   * Comma-delimited list of Mistral AI API keys.
+   */
+  mistralAIKey?: string;
   /**
    * Comma-delimited list of AWS credentials. Each credential item should be a
    * colon-delimited list of access key, secret key, and AWS region.
@@ -197,7 +206,8 @@ export const config: Config = {
   port: getEnvWithDefault("PORT", 7860),
   openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
   anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
-  googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""),
+  googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
+  mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
   awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
   azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
   proxyKey: getEnvWithDefault("PROXY_KEY", ""),
@@ -229,7 +239,10 @@ export const config: Config = {
     "gpt4-32k",
     "gpt4-turbo",
     "claude",
-    "bison",
+    "gemini-pro",
+    "mistral-tiny",
+    "mistral-small",
+    "mistral-medium",
     "aws-claude",
     "azure-turbo",
     "azure-gpt4",
@@ -361,12 +374,13 @@ export const SENSITIVE_KEYS: (keyof Config)[] = ["googleSheetsSpreadsheetId"];
  * Config keys that are not displayed on the info page at all, generally because
  * they are not relevant to the user or can be inferred from other config.
  */
-export const OMITTED_KEYS: (keyof Config)[] = [
+export const OMITTED_KEYS = [
   "port",
   "logLevel",
   "openaiKey",
   "anthropicKey",
-  "googlePalmKey",
+  "googleAIKey",
+  "mistralAIKey",
   "awsCredentials",
   "azureCredentials",
   "proxyKey",
@@ -387,34 +401,46 @@ export const OMITTED_KEYS: (keyof Config)[] = [
   "staticServiceInfo",
   "checkKeys",
   "allowedModelFamilies",
-];
+] satisfies (keyof Config)[];
+type OmitKeys = (typeof OMITTED_KEYS)[number];
+type Printable<T> = {
+  [P in keyof T as Exclude<P, OmitKeys>]: T[P] extends object
+    ? Printable<T[P]>
+    : string;
+};
+type PublicConfig = Printable<Config>;
 const getKeys = Object.keys as <T extends object>(obj: T) => Array<keyof T>;
-export function listConfig(obj: Config = config): Record<string, any> {
-  const result: Record<string, any> = {};
+export function listConfig(obj: Config = config) {
+  const result: Record<string, unknown> = {};
   for (const key of getKeys(obj)) {
     const value = obj[key]?.toString() || "";
-    const shouldOmit =
-      OMITTED_KEYS.includes(key) || value === "" || value === "undefined";
     const shouldMask = SENSITIVE_KEYS.includes(key);
+    const shouldOmit =
+      OMITTED_KEYS.includes(key as OmitKeys) ||
+      value === "" ||
+      value === "undefined";
     if (shouldOmit) {
       continue;
    }
+    const validKey = key as keyof Printable<Config>;
     if (value && shouldMask) {
-      result[key] = "********";
+      result[validKey] = "********";
     } else {
-      result[key] = value;
+      result[validKey] = value;
     }
     if (typeof obj[key] === "object" && !Array.isArray(obj[key])) {
       result[key] = listConfig(obj[key] as unknown as Config);
     }
   }
-  return result;
+  return result as PublicConfig;
 }
 /**
@@ -433,7 +459,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
     [
       "OPENAI_KEY",
       "ANTHROPIC_KEY",
-      "GOOGLE_PALM_KEY",
+      "GOOGLE_AI_KEY",
       "AWS_CREDENTIALS",
       "AZURE_CREDENTIALS",
     ].includes(String(env))
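A note on the `satisfies` hunk above: annotating OMITTED_KEYS as `(keyof Config)[]` would widen its elements to the full key union, while `satisfies` type-checks the array but preserves the literal element types, which is what lets `Printable<T>` subtract exactly those keys. A reduced sketch of the pattern (hypothetical `Demo*` names, not from the repo):

type DemoConfig = { port: number; openaiKey?: string; promptLogging: boolean };

// `satisfies` validates the elements against keyof DemoConfig while keeping
// the inferred type ("port" | "openaiKey")[], so DemoOmitKeys stays narrow.
const DEMO_OMITTED = ["port", "openaiKey"] satisfies (keyof DemoConfig)[];
type DemoOmitKeys = (typeof DEMO_OMITTED)[number];

// Key-remapped mapped type: drops the omitted keys, stringifies the rest.
type DemoPrintable<T> = {
  [P in keyof T as Exclude<P, DemoOmitKeys>]: T[P] extends object
    ? DemoPrintable<T[P]>
    : string;
};

// Resolves to { promptLogging: string }
type DemoPublicConfig = DemoPrintable<DemoConfig>;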
+42 -500
@@ -1,74 +1,39 @@
-/** This whole module really sucks */
+/** This whole module kinda sucks */
 import fs from "fs";
 import { Request, Response } from "express";
 import showdown from "showdown";
-import { config, listConfig } from "./config";
-import {
-  AnthropicKey,
-  AwsBedrockKey,
-  AzureOpenAIKey,
-  GooglePalmKey,
-  keyPool,
-  OpenAIKey,
-} from "./shared/key-management";
-import {
-  AzureOpenAIModelFamily,
-  ModelFamily,
-  OpenAIModelFamily,
-} from "./shared/models";
-import { getUniqueIps } from "./proxy/rate-limit";
-import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
-import { getTokenCostUsd, prettyTokens } from "./shared/stats";
-import { assertNever } from "./shared/utils";
+import { config } from "./config";
+import { buildInfo, ServiceInfo } from "./service-info";
 import { getLastNImages } from "./shared/file-storage/image-history";
+import { keyPool } from "./shared/key-management";
+import { MODEL_FAMILY_SERVICE, ModelFamily } from "./shared/models";
 const INFO_PAGE_TTL = 2000;
+const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
+  "turbo": "GPT-3.5 Turbo",
+  "gpt4": "GPT-4",
+  "gpt4-32k": "GPT-4 32k",
+  "gpt4-turbo": "GPT-4 Turbo",
+  "dall-e": "DALL-E",
+  "claude": "Claude",
+  "gemini-pro": "Gemini Pro",
+  "mistral-tiny": "Mistral 7B",
+  "mistral-small": "Mixtral 8x7B",
+  "mistral-medium": "Mistral Medium (prototype)",
+  "aws-claude": "AWS Claude",
+  "azure-turbo": "Azure GPT-3.5 Turbo",
+  "azure-gpt4": "Azure GPT-4",
+  "azure-gpt4-32k": "Azure GPT-4 32k",
+  "azure-gpt4-turbo": "Azure GPT-4 Turbo",
+};
-const converter = new showdown.Converter();
-const customGreeting = fs.existsSync("greeting.md")
-  ? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
-  : "";
 let infoPageHtml: string | undefined;
 let infoPageLastUpdated = 0;
-type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
-const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
-  k.service === "openai";
-const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
-  k.service === "azure";
-const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
-  k.service === "anthropic";
-const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey =>
-  k.service === "google-palm";
-const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
-type ModelAggregates = {
-  active: number;
-  trial?: number;
-  revoked?: number;
-  overQuota?: number;
-  pozzed?: number;
-  awsLogged?: number;
-  queued: number;
-  queueTime: string;
-  tokens: number;
-};
-type ModelAggregateKey = `${ModelFamily}__${keyof ModelAggregates}`;
-type ServiceAggregates = {
-  status?: string;
-  openaiKeys?: number;
-  openaiOrgs?: number;
-  anthropicKeys?: number;
-  palmKeys?: number;
-  awsKeys?: number;
-  azureKeys?: number;
-  proompts: number;
-  tokens: number;
-  tokenCost: number;
-  openAiUncheckedKeys?: number;
-  anthropicUncheckedKeys?: number;
-} & {
-  [modelFamily in ModelFamily]?: ModelAggregates;
-};
-const modelStats = new Map<ModelAggregateKey, number>();
-const serviceStats = new Map<keyof ServiceAggregates, number>();
 export const handleInfoPage = (req: Request, res: Response) => {
   if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) {
     return res.send(infoPageHtml);
@@ -79,87 +44,16 @@ export const handleInfoPage = (req: Request, res: Response) => {
     ? getExternalUrlForHuggingfaceSpaceId(process.env.SPACE_ID)
     : req.protocol + "://" + req.get("host");
-  infoPageHtml = buildInfoPageHtml(baseUrl + "/proxy");
+  const info = buildInfo(baseUrl + "/proxy");
+  infoPageHtml = renderPage(info);
   infoPageLastUpdated = Date.now();
   res.send(infoPageHtml);
 };
-function getCostString(cost: number) {
-  if (!config.showTokenCosts) return "";
-  return ` ($${cost.toFixed(2)})`;
-}
-export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
-  const keys = keyPool.list();
-  const hideFullInfo = config.staticServiceInfo && !asAdmin;
-  modelStats.clear();
-  serviceStats.clear();
-  keys.forEach(addKeyToAggregates);
-  const openaiKeys = serviceStats.get("openaiKeys") || 0;
-  const anthropicKeys = serviceStats.get("anthropicKeys") || 0;
-  const palmKeys = serviceStats.get("palmKeys") || 0;
-  const awsKeys = serviceStats.get("awsKeys") || 0;
-  const azureKeys = serviceStats.get("azureKeys") || 0;
-  const proompts = serviceStats.get("proompts") || 0;
-  const tokens = serviceStats.get("tokens") || 0;
-  const tokenCost = serviceStats.get("tokenCost") || 0;
-  const allowDalle = config.allowedModelFamilies.includes("dall-e");
-  const endpoints = {
-    ...(openaiKeys ? { openai: baseUrl + "/openai" } : {}),
-    ...(openaiKeys ? { openai2: baseUrl + "/openai/turbo-instruct" } : {}),
-    ...(openaiKeys && allowDalle
-      ? { ["openai-image"]: baseUrl + "/openai-image" }
-      : {}),
-    ...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}),
-    ...(palmKeys ? { "google-palm": baseUrl + "/google-palm" } : {}),
-    ...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}),
-    ...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}),
-  };
-  const stats = {
-    proompts,
-    tookens: `${prettyTokens(tokens)}${getCostString(tokenCost)}`,
-    ...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
-  };
-  const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys, azureKeys };
-  for (const key of Object.keys(keyInfo)) {
-    if (!(keyInfo as any)[key]) delete (keyInfo as any)[key];
-  }
-  const providerInfo = {
-    ...(openaiKeys ? getOpenAIInfo() : {}),
-    ...(anthropicKeys ? getAnthropicInfo() : {}),
-    ...(palmKeys ? getPalmInfo() : {}),
-    ...(awsKeys ? getAwsInfo() : {}),
-    ...(azureKeys ? getAzureInfo() : {}),
-  };
-  if (hideFullInfo) {
-    for (const provider of Object.keys(providerInfo)) {
-      delete (providerInfo as any)[provider].proomptersInQueue;
-      delete (providerInfo as any)[provider].estimatedQueueTime;
-      delete (providerInfo as any)[provider].usage;
-    }
-  }
-  const info = {
-    uptime: Math.floor(process.uptime()),
-    endpoints,
-    ...(hideFullInfo ? {} : stats),
-    ...keyInfo,
-    ...providerInfo,
-    config: listConfig(),
-    build: process.env.BUILD_INFO || "dev",
-  };
+export function renderPage(info: ServiceInfo) {
   const title = getServerTitle();
-  const headerHtml = buildInfoPageHeader(new showdown.Converter(), title);
+  const headerHtml = buildInfoPageHeader(info);
   return `<!DOCTYPE html>
 <html lang="en">
@@ -178,324 +72,14 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
 </html>`;
 }
-function getUniqueOpenAIOrgs(keys: KeyPoolKey[]) {
-  const orgIds = new Set(
-    keys.filter((k) => k.service === "openai").map((k: any) => k.organizationId)
-  );
-  return orgIds.size;
-}
-function increment<T extends keyof ServiceAggregates | ModelAggregateKey>(
-  map: Map<T, number>,
-  key: T,
-  delta = 1
-) {
-  map.set(key, (map.get(key) || 0) + delta);
-}
-function addKeyToAggregates(k: KeyPoolKey) {
-  increment(serviceStats, "proompts", k.promptCount);
-  increment(serviceStats, "openaiKeys", k.service === "openai" ? 1 : 0);
-  increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0);
-  increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0);
-  increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0);
-  increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0);
-  let sumTokens = 0;
-  let sumCost = 0;
-  switch (k.service) {
-    case "openai":
-      if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
-      increment(
-        serviceStats,
-        "openAiUncheckedKeys",
-        Boolean(k.lastChecked) ? 0 : 1
-      );
-      k.modelFamilies.forEach((f) => {
-        const tokens = k[`${f}Tokens`];
-        sumTokens += tokens;
-        sumCost += getTokenCostUsd(f, tokens);
-        increment(modelStats, `${f}__tokens`, tokens);
-        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
-        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
-        increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
-        increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
-      });
-      break;
-    case "azure":
-      if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
-      k.modelFamilies.forEach((f) => {
-        const tokens = k[`${f}Tokens`];
-        sumTokens += tokens;
-        sumCost += getTokenCostUsd(f, tokens);
-        increment(modelStats, `${f}__tokens`, tokens);
-        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
-        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
-      });
-      break;
-    case "anthropic": {
-      if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
-      const family = "claude";
-      sumTokens += k.claudeTokens;
-      sumCost += getTokenCostUsd(family, k.claudeTokens);
-      increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
-      increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
-      increment(modelStats, `${family}__tokens`, k.claudeTokens);
-      increment(modelStats, `${family}__pozzed`, k.isPozzed ? 1 : 0);
-      increment(
-        serviceStats,
-        "anthropicUncheckedKeys",
-        Boolean(k.lastChecked) ? 0 : 1
-      );
-      break;
-    }
-    case "google-palm": {
-      if (!keyIsGooglePalmKey(k)) throw new Error("Invalid key type");
-      const family = "bison";
-      sumTokens += k.bisonTokens;
-      sumCost += getTokenCostUsd(family, k.bisonTokens);
-      increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
-      increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
-      increment(modelStats, `${family}__tokens`, k.bisonTokens);
-      break;
-    }
-    case "aws": {
-      if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
-      const family = "aws-claude";
-      sumTokens += k["aws-claudeTokens"];
-      sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]);
-      increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
-      increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
-      increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]);
-      // Ignore revoked keys for aws logging stats, but include keys where the
-      // logging status is unknown.
-      const countAsLogged =
-        k.lastChecked && !k.isDisabled && k.awsLoggingStatus !== "disabled";
-      increment(modelStats, `${family}__awsLogged`, countAsLogged ? 1 : 0);
-      break;
-    }
-    default:
-      assertNever(k.service);
-  }
-  increment(serviceStats, "tokens", sumTokens);
-  increment(serviceStats, "tokenCost", sumCost);
-}
-function getOpenAIInfo() {
-  const info: { status?: string; openaiKeys?: number; openaiOrgs?: number } & {
-    [modelFamily in OpenAIModelFamily]?: {
-      usage?: string;
-      activeKeys: number;
-      trialKeys?: number;
-      revokedKeys?: number;
-      overQuotaKeys?: number;
-      proomptersInQueue?: number;
-      estimatedQueueTime?: string;
-    };
-  } = {};
-  const keys = keyPool.list().filter(keyIsOpenAIKey);
-  const enabledFamilies = new Set(config.allowedModelFamilies);
-  const accessibleFamilies = keys
-    .flatMap((k) => k.modelFamilies)
-    .filter((f) => enabledFamilies.has(f))
-    .concat("turbo");
-  const familySet = new Set(accessibleFamilies);
-  if (config.checkKeys) {
-    const unchecked = serviceStats.get("openAiUncheckedKeys") || 0;
-    if (unchecked > 0) {
-      info.status = `Checking ${unchecked} keys...`;
-    }
-    info.openaiKeys = keys.length;
-    info.openaiOrgs = getUniqueOpenAIOrgs(keys);
-    familySet.forEach((f) => {
-      const tokens = modelStats.get(`${f}__tokens`) || 0;
-      const cost = getTokenCostUsd(f, tokens);
-      info[f] = {
-        usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-        activeKeys: modelStats.get(`${f}__active`) || 0,
-        trialKeys: modelStats.get(`${f}__trial`) || 0,
-        revokedKeys: modelStats.get(`${f}__revoked`) || 0,
-        overQuotaKeys: modelStats.get(`${f}__overQuota`) || 0,
-      };
-      // Don't show trial/revoked keys for non-turbo families.
-      // Generally those stats only make sense for the lowest-tier model.
-      if (f !== "turbo") {
-        delete info[f]!.trialKeys;
-        delete info[f]!.revokedKeys;
-      }
-    });
-  } else {
-    info.status = "Key checking is disabled.";
-    info.turbo = { activeKeys: keys.filter((k) => !k.isDisabled).length };
-    info.gpt4 = {
-      activeKeys: keys.filter(
-        (k) => !k.isDisabled && k.modelFamilies.includes("gpt4")
-      ).length,
-    };
-  }
-  familySet.forEach((f) => {
-    if (enabledFamilies.has(f)) {
-      if (!info[f]) info[f] = { activeKeys: 0 }; // may occur if checkKeys is disabled
-      const { estimatedQueueTime, proomptersInQueue } = getQueueInformation(f);
-      info[f]!.proomptersInQueue = proomptersInQueue;
-      info[f]!.estimatedQueueTime = estimatedQueueTime;
-    } else {
-      (info[f]! as any).status = "GPT-3.5-Turbo is disabled on this proxy.";
-    }
-  });
-  return info;
-}
-function getAnthropicInfo() {
-  const claudeInfo: Partial<ModelAggregates> = {
-    active: modelStats.get("claude__active") || 0,
-    pozzed: modelStats.get("claude__pozzed") || 0,
-    revoked: modelStats.get("claude__revoked") || 0,
-  };
-  const queue = getQueueInformation("claude");
-  claudeInfo.queued = queue.proomptersInQueue;
-  claudeInfo.queueTime = queue.estimatedQueueTime;
-  const tokens = modelStats.get("claude__tokens") || 0;
-  const cost = getTokenCostUsd("claude", tokens);
-  const unchecked =
-    (config.checkKeys && serviceStats.get("anthropicUncheckedKeys")) || 0;
-  return {
-    claude: {
-      usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-      ...(unchecked > 0 ? { status: `Checking ${unchecked} keys...` } : {}),
-      activeKeys: claudeInfo.active,
-      revokedKeys: claudeInfo.revoked,
-      ...(config.checkKeys ? { pozzedKeys: claudeInfo.pozzed } : {}),
-      proomptersInQueue: claudeInfo.queued,
-      estimatedQueueTime: claudeInfo.queueTime,
-    },
-  };
-}
-function getPalmInfo() {
-  const bisonInfo: Partial<ModelAggregates> = {
-    active: modelStats.get("bison__active") || 0,
-    revoked: modelStats.get("bison__revoked") || 0,
-  };
-  const queue = getQueueInformation("bison");
-  bisonInfo.queued = queue.proomptersInQueue;
-  bisonInfo.queueTime = queue.estimatedQueueTime;
-  const tokens = modelStats.get("bison__tokens") || 0;
-  const cost = getTokenCostUsd("bison", tokens);
-  return {
-    bison: {
-      usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-      activeKeys: bisonInfo.active,
-      revokedKeys: bisonInfo.revoked,
-      proomptersInQueue: bisonInfo.queued,
-      estimatedQueueTime: bisonInfo.queueTime,
-    },
-  };
-}
-function getAwsInfo() {
-  const awsInfo: Partial<ModelAggregates> = {
-    active: modelStats.get("aws-claude__active") || 0,
-    revoked: modelStats.get("aws-claude__revoked") || 0,
-  };
-  const queue = getQueueInformation("aws-claude");
-  awsInfo.queued = queue.proomptersInQueue;
-  awsInfo.queueTime = queue.estimatedQueueTime;
-  const tokens = modelStats.get("aws-claude__tokens") || 0;
-  const cost = getTokenCostUsd("aws-claude", tokens);
-  const logged = modelStats.get("aws-claude__awsLogged") || 0;
-  const logMsg = config.allowAwsLogging
-    ? `${logged} active keys are potentially logged.`
-    : `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`;
-  return {
-    "aws-claude": {
-      usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-      activeKeys: awsInfo.active,
-      revokedKeys: awsInfo.revoked,
-      proomptersInQueue: awsInfo.queued,
-      estimatedQueueTime: awsInfo.queueTime,
-      ...(logged > 0 ? { privacy: logMsg } : {}),
-    },
-  };
-}
-function getAzureInfo() {
-  const azureFamilies = [
-    "azure-turbo",
-    "azure-gpt4",
-    "azure-gpt4-turbo",
-    "azure-gpt4-32k",
-  ] as const;
-  const azureInfo: {
-    [modelFamily in AzureOpenAIModelFamily]?: {
-      usage?: string;
-      activeKeys: number;
-      revokedKeys?: number;
-      proomptersInQueue?: number;
-      estimatedQueueTime?: string;
-    };
-  } = {};
-  for (const family of azureFamilies) {
-    const familyAllowed = config.allowedModelFamilies.includes(family);
-    const activeKeys = modelStats.get(`${family}__active`) || 0;
-    if (!familyAllowed || activeKeys === 0) continue;
-    azureInfo[family] = {
-      activeKeys,
-      revokedKeys: modelStats.get(`${family}__revoked`) || 0,
-    };
-    const queue = getQueueInformation(family);
-    azureInfo[family]!.proomptersInQueue = queue.proomptersInQueue;
-    azureInfo[family]!.estimatedQueueTime = queue.estimatedQueueTime;
-    const tokens = modelStats.get(`${family}__tokens`) || 0;
-    const cost = getTokenCostUsd(family, tokens);
-    azureInfo[family]!.usage = `${prettyTokens(tokens)} tokens${getCostString(
-      cost
-    )}`;
-  }
-  return azureInfo;
-}
+const customGreeting = fs.existsSync("greeting.md")
+  ? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
+  : "";
 /**
  * If the server operator provides a `greeting.md` file, it will be included in
  * the rendered info page.
  **/
-function buildInfoPageHeader(converter: showdown.Converter, title: string) {
+function buildInfoPageHeader(info: ServiceInfo) {
+  const title = getServerTitle();
   // TODO: use some templating engine instead of this mess
-  let infoBody = `<!-- Header for Showdown's parser, don't remove this line -->
-# ${title}`;
+  let infoBody = `# ${title}`;
   if (config.promptLogging) {
     infoBody += `\n## Prompt Logging Enabled
 This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps.
@@ -510,45 +94,18 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
   }
   const waits: string[] = [];
-  infoBody += `\n## Estimated Wait Times`;
-  if (config.openaiKey) {
-    // TODO: un-fuck this
-    const keys = keyPool.list().filter((k) => k.service === "openai");
-    const turboWait = getQueueInformation("turbo").estimatedQueueTime;
-    waits.push(`**Turbo:** ${turboWait}`);
-    const gpt4Wait = getQueueInformation("gpt4").estimatedQueueTime;
-    const hasGpt4 = keys.some((k) => k.modelFamilies.includes("gpt4"));
-    const allowedGpt4 = config.allowedModelFamilies.includes("gpt4");
-    if (hasGpt4 && allowedGpt4) {
-      waits.push(`**GPT-4:** ${gpt4Wait}`);
-    }
-    const gpt432kWait = getQueueInformation("gpt4-32k").estimatedQueueTime;
-    const hasGpt432k = keys.some((k) => k.modelFamilies.includes("gpt4-32k"));
-    const allowedGpt432k = config.allowedModelFamilies.includes("gpt4-32k");
-    if (hasGpt432k && allowedGpt432k) {
-      waits.push(`**GPT-4-32k:** ${gpt432kWait}`);
-    }
-    const dalleWait = getQueueInformation("dall-e").estimatedQueueTime;
-    const hasDalle = keys.some((k) => k.modelFamilies.includes("dall-e"));
-    const allowedDalle = config.allowedModelFamilies.includes("dall-e");
-    if (hasDalle && allowedDalle) {
-      waits.push(`**DALL-E:** ${dalleWait}`);
-    }
-  }
-  if (config.anthropicKey) {
-    const claudeWait = getQueueInformation("claude").estimatedQueueTime;
-    waits.push(`**Claude:** ${claudeWait}`);
-  }
-  if (config.awsCredentials) {
-    const awsClaudeWait = getQueueInformation("aws-claude").estimatedQueueTime;
-    waits.push(`**Claude (AWS):** ${awsClaudeWait}`);
-  }
+  for (const modelFamily of config.allowedModelFamilies) {
+    const service = MODEL_FAMILY_SERVICE[modelFamily];
+    const hasKeys = keyPool.list().some((k) => {
+      return k.service === service && k.modelFamilies.includes(modelFamily);
+    });
+    const wait = info[modelFamily]?.estimatedQueueTime;
+    if (hasKeys && wait) {
+      waits.push(`**${MODEL_FAMILY_FRIENDLY_NAME[modelFamily] || modelFamily}**: ${wait}`);
+    }
+  }
   infoBody += "\n\n" + waits.join(" / ");
@@ -565,21 +122,6 @@ function getSelfServiceLinks() {
   return `<footer style="font-size: 0.8em;"><hr /><a target="_blank" href="/user/lookup">Check your user token info</a></footer>`;
 }
-/** Returns queue time in seconds, or minutes + seconds if over 60 seconds. */
-function getQueueInformation(partition: ModelFamily) {
-  const waitMs = getEstimatedWaitTime(partition);
-  const waitTime =
-    waitMs < 60000
-      ? `${Math.round(waitMs / 1000)}sec`
-      : `${Math.round(waitMs / 60000)}min, ${Math.round(
-          (waitMs % 60000) / 1000
-        )}sec`;
-  return {
-    proomptersInQueue: getQueueLength(partition),
-    estimatedQueueTime: waitMs > 2000 ? waitTime : "no wait",
-  };
-}
 function getServerTitle() {
   // Use manually set title if available
   if (process.env.SERVER_TITLE) {
+140
@@ -0,0 +1,140 @@
+import { Request, RequestHandler, Router } from "express";
+import { createProxyMiddleware } from "http-proxy-middleware";
+import { v4 } from "uuid";
+import { config } from "../config";
+import { logger } from "../logger";
+import { createQueueMiddleware } from "./queue";
+import { ipLimiter } from "./rate-limit";
+import { handleProxyError } from "./middleware/common";
+import {
+  createOnProxyReqHandler,
+  createPreprocessorMiddleware,
+  finalizeSignedRequest,
+  forceModel,
+} from "./middleware/request";
+import {
+  createOnProxyResHandler,
+  ProxyResHandlerWithBody,
+} from "./middleware/response";
+import { addGoogleAIKey } from "./middleware/request/preprocessors/add-google-ai-key";
+let modelsCache: any = null;
+let modelsCacheTime = 0;
+const getModelsResponse = () => {
+  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
+    return modelsCache;
+  }
+  if (!config.googleAIKey) return { object: "list", data: [] };
+  const googleAIVariants = ["gemini-pro"];
+  const models = googleAIVariants.map((id) => ({
+    id,
+    object: "model",
+    created: new Date().getTime(),
+    owned_by: "google",
+    permission: [],
+    root: "google",
+    parent: null,
+  }));
+  modelsCache = { object: "list", data: models };
+  modelsCacheTime = new Date().getTime();
+  return modelsCache;
+};
+const handleModelRequest: RequestHandler = (_req, res) => {
+  res.status(200).json(getModelsResponse());
+};
+/** Only used for non-streaming requests. */
+const googleAIResponseHandler: ProxyResHandlerWithBody = async (
+  _proxyRes,
+  req,
+  res,
+  body
+) => {
+  if (typeof body !== "object") {
+    throw new Error("Expected body to be an object");
+  }
+  if (config.promptLogging) {
+    const host = req.get("host");
+    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
+  }
+  if (req.inboundApi === "openai") {
+    req.log.info("Transforming Google AI response to OpenAI format");
+    body = transformGoogleAIResponse(body, req);
+  }
+  if (req.tokenizerInfo) {
+    body.proxy_tokenizer = req.tokenizerInfo;
+  }
+  res.status(200).json(body);
+};
+function transformGoogleAIResponse(
+  resBody: Record<string, any>,
+  req: Request
+): Record<string, any> {
+  const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
+  const parts = resBody.candidates[0].content?.parts ?? [{ text: "" }];
+  const content = parts[0].text.replace(/^(.{0,50}?): /, () => "");
+  return {
+    id: "goo-" + v4(),
+    object: "chat.completion",
+    created: Date.now(),
+    model: req.body.model,
+    usage: {
+      prompt_tokens: req.promptTokens,
+      completion_tokens: req.outputTokens,
+      total_tokens: totalTokens,
+    },
+    choices: [
+      {
+        message: { role: "assistant", content },
+        finish_reason: resBody.candidates[0].finishReason,
+        index: 0,
+      },
+    ],
+  };
+}
+const googleAIProxy = createQueueMiddleware({
+  beforeProxy: addGoogleAIKey,
+  proxyMiddleware: createProxyMiddleware({
+    target: "bad-target-will-be-rewritten",
+    router: ({ signedRequest }) => {
+      const { protocol, hostname, path } = signedRequest;
+      return `${protocol}//${hostname}${path}`;
+    },
+    changeOrigin: true,
+    selfHandleResponse: true,
+    logger,
+    on: {
+      proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
+      proxyRes: createOnProxyResHandler([googleAIResponseHandler]),
+      error: handleProxyError,
+    },
+  }),
+});
+const googleAIRouter = Router();
+googleAIRouter.get("/v1/models", handleModelRequest);
+// OpenAI-to-Google AI compatibility endpoint.
+googleAIRouter.post(
+  "/v1/chat/completions",
+  ipLimiter,
+  createPreprocessorMiddleware(
+    { inApi: "openai", outApi: "google-ai", service: "google-ai" },
+    { afterTransform: [forceModel("gemini-pro")] }
+  ),
+  googleAIProxy
+);
+export const googleAI = googleAIRouter;
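A client call against the OpenAI-compatible route registered above might look like the following; this is an illustrative sketch, not code from the repo, and it assumes a local proxy on port 7860 with proxy key "test" (mirroring the load-test script earlier in this diff):

async function main() {
  const res = await fetch(
    "http://localhost:7860/proxy/google-ai/v1/chat/completions",
    {
      method: "POST",
      headers: {
        Authorization: "Bearer test", // assumed PROXY_KEY
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        model: "gemini-pro", // forceModel("gemini-pro") pins this regardless
        max_tokens: 20,
        messages: [{ role: "user", content: "Hello!" }],
      }),
    }
  );
  // Non-streaming responses are transformed to OpenAI chat format.
  const body = await res.json();
  console.log(body.choices[0].message.content);
}
main().catch(console.error);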
+55 -22
@@ -2,7 +2,7 @@ import { Request, Response } from "express";
 import httpProxy from "http-proxy";
 import { ZodError } from "zod";
 import { generateErrorMessage } from "zod-error";
-import { buildFakeSse } from "../../shared/streaming";
+import { makeCompletionSSE } from "../../shared/streaming";
 import { assertNever } from "../../shared/utils";
 import { QuotaExceededError } from "./request/preprocessors/apply-quota-limits";
@@ -40,11 +40,13 @@ export function writeErrorResponse(
   req: Request,
   res: Response,
   statusCode: number,
+  statusMessage: string,
   errorPayload: Record<string, any>
 ) {
-  const errorSource = errorPayload.error?.type?.startsWith("proxy")
-    ? "proxy"
-    : "upstream";
+  const msg =
+    statusCode === 500
+      ? `The proxy encountered an error while trying to process your prompt.`
+      : `The proxy encountered an error while trying to send your prompt to the upstream service.`;
   // If we're mid-SSE stream, send a data event with the error payload and end
   // the stream. Otherwise just send a normal error response.
@@ -52,10 +54,15 @@ export function writeErrorResponse(
     res.headersSent ||
     String(res.getHeader("content-type")).startsWith("text/event-stream")
   ) {
-    const errorTitle = `${errorSource} error (${statusCode})`;
-    const errorContent = JSON.stringify(errorPayload, null, 2);
-    const msg = buildFakeSse(errorTitle, errorContent, req);
-    res.write(msg);
+    const event = makeCompletionSSE({
+      format: req.inboundApi,
+      title: `Proxy error (HTTP ${statusCode} ${statusMessage})`,
+      message: `${msg} Further technical details are provided below.`,
+      obj: errorPayload,
+      reqId: req.id,
+      model: req.body?.model,
+    });
+    res.write(event);
     res.write(`data: [DONE]\n\n`);
     res.end();
   } else {
@@ -77,8 +84,9 @@ export const classifyErrorAndSend = (
   res: Response
 ) => {
   try {
-    const { status, userMessage, ...errorDetails } = classifyError(err);
-    writeErrorResponse(req, res, status, {
+    const { statusCode, statusMessage, userMessage, ...errorDetails } =
+      classifyError(err);
+    writeErrorResponse(req, res, statusCode, statusMessage, {
       error: { message: userMessage, ...errorDetails },
     });
   } catch (error) {
@@ -88,14 +96,17 @@ export const classifyErrorAndSend = (
 function classifyError(err: Error): {
   /** HTTP status code returned to the client. */
-  status: number;
+  statusCode: number;
+  /** HTTP status message returned to the client. */
+  statusMessage: string;
   /** Message displayed to the user. */
   userMessage: string;
   /** Short error type, e.g. "proxy_validation_error". */
   type: string;
 } & Record<string, any> {
   const defaultError = {
-    status: 500,
+    statusCode: 500,
+    statusMessage: "Internal Server Error",
     userMessage: `Reverse proxy error: ${err.message}`,
     type: "proxy_internal_error",
     stack: err.stack,
@@ -112,19 +123,33 @@ function classifyError(err: Error): {
           return `At '${rest.pathComponent}': ${issue.message}`;
         },
       });
-      return { status: 400, userMessage, type: "proxy_validation_error" };
-    case "ForbiddenError":
+      return {
+        statusCode: 400,
+        statusMessage: "Bad Request",
+        userMessage,
+        type: "proxy_validation_error",
+      };
+    case "ZoomerForbiddenError":
       // Mimics a ban notice from OpenAI, thrown when blockZoomerOrigins blocks
       // a request.
       return {
-        status: 403,
+        statusCode: 403,
+        statusMessage: "Forbidden",
        userMessage: `Your account has been disabled for violating our terms of service.`,
        type: "organization_account_disabled",
        code: "policy_violation",
      };
+    case "ForbiddenError":
+      return {
+        statusCode: 403,
+        statusMessage: "Forbidden",
+        userMessage: `Request is not allowed. (${err.message})`,
+        type: "proxy_forbidden",
+      };
     case "QuotaExceededError":
       return {
-        status: 429,
+        statusCode: 429,
+        statusMessage: "Too Many Requests",
         userMessage: `You've exceeded your token quota for this model type.`,
         type: "proxy_quota_exceeded",
         info: (err as QuotaExceededError).quotaInfo,
@@ -134,21 +159,24 @@ function classifyError(err: Error): {
       switch (err.code) {
         case "ENOTFOUND":
           return {
-            status: 502,
+            statusCode: 502,
+            statusMessage: "Bad Gateway",
             userMessage: `Reverse proxy encountered a DNS error while trying to connect to the upstream service.`,
             type: "proxy_network_error",
             code: err.code,
           };
         case "ECONNREFUSED":
           return {
-            status: 502,
+            statusCode: 502,
+            statusMessage: "Bad Gateway",
             userMessage: `Reverse proxy couldn't connect to the upstream service.`,
             type: "proxy_network_error",
             code: err.code,
           };
         case "ECONNRESET":
          return {
-            status: 504,
+            statusCode: 504,
+            statusMessage: "Gateway Timeout",
             userMessage: `Reverse proxy timed out while waiting for the upstream service to respond.`,
             type: "proxy_network_error",
             code: err.code,
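Tracing the new statusCode/statusMessage split through classifyErrorAndSend above: an ECONNREFUSED from the upstream should now surface to the client roughly as follows (reconstructed from this hunk, not captured output; statusMessage is destructured out of errorDetails, so only message/type/code reach the JSON body):

HTTP/1.1 502 Bad Gateway

{
  "error": {
    "message": "Reverse proxy couldn't connect to the upstream service.",
    "type": "proxy_network_error",
    "code": "ECONNREFUSED"
  }
}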
@@ -165,6 +193,7 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
   const format = req.outboundApi;
   switch (format) {
     case "openai":
+    case "mistral-ai":
       return body.choices[0].message.content;
     case "openai-text":
       return body.choices[0].text;
@@ -177,8 +206,11 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
         return "";
       }
       return body.completion.trim();
-    case "google-palm":
-      return body.candidates[0].output;
+    case "google-ai":
+      if ("choices" in body) {
+        return body.choices[0].message.content;
+      }
+      return body.candidates[0].content.parts[0].text;
     case "openai-image":
       return body.data?.map((item: any) => item.url).join("\n");
     default:
@@ -191,13 +223,14 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
   switch (format) {
     case "openai":
     case "openai-text":
+    case "mistral-ai":
       return body.model;
     case "openai-image":
       return req.body.model;
     case "anthropic":
       // Anthropic confirms the model in the response, but AWS Claude doesn't.
       return body.model || req.body.model;
-    case "google-palm":
+    case "google-ai":
       // Google doesn't confirm the model in the response.
       return req.body.model;
     default:
@@ -29,7 +29,9 @@ export const createOnProxyReqHandler = ({
   // The streaming flag must be set before any other onProxyReq handler runs,
   // as it may influence the behavior of subsequent handlers.
   // Image generation requests can't be streamed.
-  req.isStreaming = req.body.stream === true || req.body.stream === "true";
+  // TODO: this flag is set in too many places
+  req.isStreaming =
+    req.isStreaming || req.body.stream === true || req.body.stream === "true";
   req.body.stream = req.isStreaming;
   try {
@@ -31,10 +31,6 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
case "anthropic": case "anthropic":
assignedKey = keyPool.get("claude-v1"); assignedKey = keyPool.get("claude-v1");
break; break;
case "google-palm":
assignedKey = keyPool.get("text-bison-001");
delete req.body.stream;
break;
case "openai-text": case "openai-text":
assignedKey = keyPool.get("gpt-3.5-turbo-instruct"); assignedKey = keyPool.get("gpt-3.5-turbo-instruct");
break; break;
@@ -42,6 +38,10 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
throw new Error( throw new Error(
"OpenAI Chat as an API translation target is not supported" "OpenAI Chat as an API translation target is not supported"
); );
case "google-ai":
throw new Error("add-key should not be used for this model.");
case "mistral-ai":
throw new Error("Mistral AI should never be translated");
case "openai-image": case "openai-image":
assignedKey = keyPool.get("dall-e-3"); assignedKey = keyPool.get("dall-e-3");
break; break;
@@ -71,23 +71,16 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
if (key.organizationId) { if (key.organizationId) {
proxyReq.setHeader("OpenAI-Organization", key.organizationId); proxyReq.setHeader("OpenAI-Organization", key.organizationId);
} }
case "mistral-ai":
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`); proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
break; break;
case "google-palm":
const originalPath = proxyReq.path;
proxyReq.path = originalPath.replace(
/(\?.*)?$/,
`?key=${assignedKey.key}`
);
break;
case "azure": case "azure":
const azureKey = assignedKey.key; const azureKey = assignedKey.key;
proxyReq.setHeader("api-key", azureKey); proxyReq.setHeader("api-key", azureKey);
break; break;
case "aws": case "aws":
throw new Error( case "google-ai":
"add-key should not be used for AWS security credentials. Use sign-aws-request instead." throw new Error("add-key should not be used for this service.");
);
default: default:
assertNever(assignedKey.service); assertNever(assignedKey.service);
} }
@@ -2,10 +2,10 @@ import { HPMRequestCallback } from "../index";
 const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(",");
-class ForbiddenError extends Error {
+class ZoomerForbiddenError extends Error {
   constructor(message: string) {
     super(message);
-    this.name = "ForbiddenError";
+    this.name = "ZoomerForbiddenError";
   }
 }
@@ -22,7 +22,7 @@ export const blockZoomerOrigins: HPMRequestCallback = (_proxyReq, req) => {
     return;
   }
-  throw new ForbiddenError(
+  throw new ZoomerForbiddenError(
     `Your access was terminated due to violation of our policies, please check your email for more information. If you believe this is in error and would like to appeal, please contact us through our help center at help.openai.com.`
   );
 }
@@ -1,13 +1,14 @@
 import { HPMRequestCallback } from "../index";
 import { config } from "../../../../config";
+import { ForbiddenError } from "../../../../shared/errors";
 import { getModelFamilyForRequest } from "../../../../shared/models";
 /**
  * Ensures the selected model family is enabled by the proxy configuration.
  **/
-export const checkModelFamily: HPMRequestCallback = (proxyReq, req) => {
+export const checkModelFamily: HPMRequestCallback = (_proxyReq, req, res) => {
   const family = getModelFamilyForRequest(req);
   if (!config.allowedModelFamilies.includes(family)) {
-    throw new Error(`Model family ${family} is not permitted on this proxy`);
+    throw new ForbiddenError(`Model family '${family}' is not enabled on this proxy`);
   }
 };
@@ -1,9 +1,9 @@
 import type { HPMRequestCallback } from "../index";
 /**
- * For AWS/Azure requests, the body is signed earlier in the request pipeline,
- * before the proxy middleware. This function just assigns the path and headers
- * to the proxy request.
+ * For AWS/Azure/Google requests, the body is signed earlier in the request
+ * pipeline, before the proxy middleware. This function just assigns the path
+ * and headers to the proxy request.
  */
 export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => {
   if (!req.signedRequest) {
@@ -0,0 +1,40 @@
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
export const addGoogleAIKey: RequestPreprocessor = (req) => {
const apisValid = req.inboundApi === "openai" && req.outboundApi === "google-ai";
const serviceValid = req.service === "google-ai";
if (!apisValid || !serviceValid) {
throw new Error("addGoogleAIKey called on invalid request");
}
if (!req.body?.model) {
throw new Error("You must specify a model with your request.");
}
const model = req.body.model;
req.key = keyPool.get(model);
req.log.info(
{ key: req.key.hash, model },
"Assigned Google AI API key to request"
);
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:generateContent?key=$API_KEY
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY}
req.isStreaming = req.isStreaming || req.body.stream;
delete req.body.stream;
req.signedRequest = {
method: "POST",
protocol: "https:",
hostname: "generativelanguage.googleapis.com",
path: `/v1beta/models/${model}:${req.isStreaming ? "streamGenerateContent" : "generateContent"}?key=${req.key.key}`,
headers: {
["host"]: `generativelanguage.googleapis.com`,
["content-type"]: "application/json",
},
body: JSON.stringify(req.body),
};
};
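The URL construction above is worth sanity-checking in isolation. A minimal standalone sketch of the path this preprocessor builds for a streaming gemini-pro request; the key value is a placeholder, not a real credential:

```ts
// Sketch of the path built by addGoogleAIKey above.
// "AIzaPLACEHOLDER" is a fake key for illustration only.
const model = "gemini-pro";
const isStreaming = true;
const key = "AIzaPLACEHOLDER";

const path = `/v1beta/models/${model}:${
  isStreaming ? "streamGenerateContent" : "generateContent"
}?key=${key}`;

console.log(path);
// => /v1beta/models/gemini-pro:streamGenerateContent?key=AIzaPLACEHOLDER
```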
@@ -1,7 +1,11 @@
 import { RequestPreprocessor } from "../index";
 import { countTokens } from "../../../../shared/tokenization";
 import { assertNever } from "../../../../shared/utils";
-import type { OpenAIChatMessage } from "./transform-outbound-payload";
+import type {
+  GoogleAIChatMessage,
+  MistralAIChatMessage,
+  OpenAIChatMessage,
+} from "./transform-outbound-payload";
 /**
  * Given a request with an already-transformed body, counts the number of
@@ -30,9 +34,15 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
       result = await countTokens({ req, prompt, service });
       break;
     }
-    case "google-palm": {
-      req.outputTokens = req.body.maxOutputTokens;
-      const prompt: string = req.body.prompt.text;
+    case "google-ai": {
+      req.outputTokens = req.body.generationConfig.maxOutputTokens;
+      const prompt: GoogleAIChatMessage[] = req.body.contents;
+      result = await countTokens({ req, prompt, service });
+      break;
+    }
+    case "mistral-ai": {
+      req.outputTokens = req.body.max_tokens;
+      const prompt: MistralAIChatMessage[] = req.body.messages;
       result = await countTokens({ req, prompt, service });
       break;
     }
@@ -3,7 +3,10 @@ import { config } from "../../../../config";
 import { assertNever } from "../../../../shared/utils";
 import { RequestPreprocessor } from "../index";
 import { UserInputError } from "../../../../shared/errors";
-import { OpenAIChatMessage } from "./transform-outbound-payload";
+import {
+  MistralAIChatMessage,
+  OpenAIChatMessage,
+} from "./transform-outbound-payload";
 const rejectedClients = new Map<string, number>();
@@ -53,8 +56,9 @@ function getPromptFromRequest(req: Request) {
     case "anthropic":
       return body.prompt;
     case "openai":
+    case "mistral-ai":
       return body.messages
-        .map((msg: OpenAIChatMessage) => {
+        .map((msg: OpenAIChatMessage | MistralAIChatMessage) => {
           const text = Array.isArray(msg.content)
             ? msg.content
                 .map((c) => {
@@ -68,7 +72,7 @@ function getPromptFromRequest(req: Request) {
     case "openai-text":
     case "openai-image":
       return body.prompt;
-    case "google-palm":
+    case "google-ai":
       return body.prompt.text;
     default:
       assertNever(service);
@@ -1,13 +1,14 @@
 import { Request } from "express";
-import { APIFormat, LLMService } from "../../../../shared/key-management";
+import { APIFormat } from "../../../../shared/key-management";
+import { LLMService } from "../../../../shared/models";
 import { RequestPreprocessor } from "../index";
 export const setApiFormat = (api: {
   inApi: Request["inboundApi"];
   outApi: APIFormat;
-  service: LLMService,
+  service: LLMService;
 }): RequestPreprocessor => {
-  return function configureRequestApiFormat (req) {
+  return function configureRequestApiFormat(req) {
     req.inboundApi = api.inApi;
     req.outboundApi = api.outApi;
     req.service = api.service;
@@ -32,7 +32,7 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
     temperature: true,
     top_k: true,
     top_p: true,
-  }).parse(req.body);
+  }).strip().parse(req.body);
   const credential = getCredentialParts(req);
   const host = AMZ_HOST.replace("%REGION%", credential.region);
@@ -1,7 +1,10 @@
 import { Request } from "express";
 import { z } from "zod";
 import { config } from "../../../../config";
-import { isTextGenerationRequest, isImageGenerationRequest } from "../../common";
+import {
+  isTextGenerationRequest,
+  isImageGenerationRequest,
+} from "../../common";
 import { RequestPreprocessor } from "../index";
 import { APIFormat } from "../../../../shared/key-management";
@@ -11,23 +14,24 @@ const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
 // TODO: move schemas to shared
 // https://console.anthropic.com/docs/api/reference#-v1-complete
-export const AnthropicV1CompleteSchema = z.object({
-  model: z.string(),
-  prompt: z.string({
-    required_error:
-      "No prompt found. Are you sending an OpenAI-formatted request to the Claude endpoint?",
-  }),
-  max_tokens_to_sample: z.coerce
-    .number()
-    .int()
-    .transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
-  stop_sequences: z.array(z.string()).optional(),
-  stream: z.boolean().optional().default(false),
-  temperature: z.coerce.number().optional().default(1),
-  top_k: z.coerce.number().optional(),
-  top_p: z.coerce.number().optional(),
-  metadata: z.any().optional(),
-});
+export const AnthropicV1CompleteSchema = z
+  .object({
+    model: z.string().max(100),
+    prompt: z.string({
+      required_error:
+        "No prompt found. Are you sending an OpenAI-formatted request to the Claude endpoint?",
+    }),
+    max_tokens_to_sample: z.coerce
+      .number()
+      .int()
+      .transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
+    stop_sequences: z.array(z.string().max(500)).optional(),
+    stream: z.boolean().optional().default(false),
+    temperature: z.coerce.number().optional().default(1),
+    top_k: z.coerce.number().optional(),
+    top_p: z.coerce.number().optional(),
+  })
+  .strip();
 // https://platform.openai.com/docs/api-reference/chat/create
 const OpenAIV1ChatContentArraySchema = z.array(
@@ -43,44 +47,48 @@ const OpenAIV1ChatContentArraySchema = z.array(
   ])
 );
-export const OpenAIV1ChatCompletionSchema = z.object({
-  model: z.string(),
-  messages: z.array(
-    z.object({
-      role: z.enum(["system", "user", "assistant"]),
-      content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
-      name: z.string().optional(),
-    }),
-    {
-      required_error:
-        "No `messages` found. Ensure you've set the correct completion endpoint.",
-      invalid_type_error:
-        "Messages were not formatted correctly. Refer to the OpenAI Chat API documentation for more information.",
-    }
-  ),
-  temperature: z.number().optional().default(1),
-  top_p: z.number().optional().default(1),
-  n: z
-    .literal(1, {
-      errorMap: () => ({
-        message: "You may only request a single completion at a time.",
-      }),
-    })
-    .optional(),
-  stream: z.boolean().optional().default(false),
-  stop: z.union([z.string(), z.array(z.string())]).optional(),
-  max_tokens: z.coerce
-    .number()
-    .int()
-    .nullish()
-    .default(16)
-    .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
-  frequency_penalty: z.number().optional().default(0),
-  presence_penalty: z.number().optional().default(0),
-  logit_bias: z.any().optional(),
-  user: z.string().optional(),
-  seed: z.number().int().optional(),
-});
+export const OpenAIV1ChatCompletionSchema = z
+  .object({
+    model: z.string().max(100),
+    messages: z.array(
+      z.object({
+        role: z.enum(["system", "user", "assistant"]),
+        content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
+        name: z.string().optional(),
+      }),
+      {
+        required_error:
+          "No `messages` found. Ensure you've set the correct completion endpoint.",
+        invalid_type_error:
+          "Messages were not formatted correctly. Refer to the OpenAI Chat API documentation for more information.",
+      }
+    ),
+    temperature: z.number().optional().default(1),
+    top_p: z.number().optional().default(1),
+    n: z
+      .literal(1, {
+        errorMap: () => ({
+          message: "You may only request a single completion at a time.",
+        }),
+      })
+      .optional(),
+    stream: z.boolean().optional().default(false),
+    stop: z
+      .union([z.string().max(500), z.array(z.string().max(500))])
+      .optional(),
+    max_tokens: z.coerce
+      .number()
+      .int()
+      .nullish()
+      .default(16)
+      .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
+    frequency_penalty: z.number().optional().default(0),
+    presence_penalty: z.number().optional().default(0),
+    logit_bias: z.any().optional(),
+    user: z.string().max(500).optional(),
+    seed: z.number().int().optional(),
+  })
+  .strip();
 export type OpenAIChatMessage = z.infer<
   typeof OpenAIV1ChatCompletionSchema
@@ -90,6 +98,7 @@ const OpenAIV1TextCompletionSchema = z
   .object({
     model: z
       .string()
+      .max(100)
       .regex(
         /^gpt-3.5-turbo-instruct/,
         "Model must start with 'gpt-3.5-turbo-instruct'"
@@ -101,50 +110,96 @@ const OpenAIV1TextCompletionSchema = z
     logprobs: z.number().int().nullish().default(null),
     echo: z.boolean().optional().default(false),
     best_of: z.literal(1).optional(),
-    stop: z.union([z.string(), z.array(z.string()).max(4)]).optional(),
-    suffix: z.string().optional(),
+    stop: z
+      .union([z.string().max(500), z.array(z.string().max(500)).max(4)])
+      .optional(),
+    suffix: z.string().max(1000).optional(),
   })
+  .strip()
   .merge(OpenAIV1ChatCompletionSchema.omit({ messages: true }));
 // https://platform.openai.com/docs/api-reference/images/create
-const OpenAIV1ImagesGenerationSchema = z.object({
-  prompt: z.string().max(4000),
-  model: z.string().optional(),
-  quality: z.enum(["standard", "hd"]).optional().default("standard"),
-  n: z.number().int().min(1).max(4).optional().default(1),
-  response_format: z.enum(["url", "b64_json"]).optional(),
-  size: z
-    .enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"])
-    .optional()
-    .default("1024x1024"),
-  style: z.enum(["vivid", "natural"]).optional().default("vivid"),
-  user: z.string().optional(),
-});
+const OpenAIV1ImagesGenerationSchema = z
+  .object({
+    prompt: z.string().max(4000),
+    model: z.string().max(100).optional(),
+    quality: z.enum(["standard", "hd"]).optional().default("standard"),
+    n: z.number().int().min(1).max(4).optional().default(1),
+    response_format: z.enum(["url", "b64_json"]).optional(),
+    size: z
+      .enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"])
+      .optional()
+      .default("1024x1024"),
+    style: z.enum(["vivid", "natural"]).optional().default("vivid"),
+    user: z.string().max(500).optional(),
+  })
+  .strip();
-// https://developers.generativeai.google/api/rest/generativelanguage/models/generateText
-const PalmV1GenerateTextSchema = z.object({
-  model: z.string(),
-  prompt: z.object({ text: z.string() }),
-  temperature: z.number().optional(),
-  maxOutputTokens: z.coerce
-    .number()
-    .int()
-    .optional()
-    .default(16)
-    .transform((v) => Math.min(v, 1024)), // TODO: Add config
-  candidateCount: z.literal(1).optional(),
-  topP: z.number().optional(),
-  topK: z.number().optional(),
-  safetySettings: z.array(z.object({})).max(0).optional(),
-  stopSequences: z.array(z.string()).max(5).optional(),
-});
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
+const GoogleAIV1GenerateContentSchema = z
+  .object({
+    model: z.string().max(100), // actually specified in path but we need it for the router
+    stream: z.boolean().optional().default(false), // also used for router
+    contents: z.array(
+      z.object({
+        parts: z.array(z.object({ text: z.string() })),
+        role: z.enum(["user", "model"]),
+      })
+    ),
+    tools: z.array(z.object({})).max(0).optional(),
+    safetySettings: z.array(z.object({})).max(0).optional(),
+    generationConfig: z.object({
+      temperature: z.number().optional(),
+      maxOutputTokens: z.coerce
+        .number()
+        .int()
+        .optional()
+        .default(16)
+        .transform((v) => Math.min(v, 1024)), // TODO: Add config
+      candidateCount: z.literal(1).optional(),
+      topP: z.number().optional(),
+      topK: z.number().optional(),
+      stopSequences: z.array(z.string().max(500)).max(5).optional(),
+    }),
+  })
+  .strip();
+export type GoogleAIChatMessage = z.infer<
+  typeof GoogleAIV1GenerateContentSchema
+>["contents"][0];
+// https://docs.mistral.ai/api#operation/createChatCompletion
+const MistralAIV1ChatCompletionsSchema = z.object({
+  model: z.string(),
+  messages: z.array(
+    z.object({
+      role: z.enum(["system", "user", "assistant"]),
+      content: z.string(),
+    })
+  ),
+  temperature: z.number().optional().default(0.7),
+  top_p: z.number().optional().default(1),
+  max_tokens: z.coerce
+    .number()
+    .int()
+    .nullish()
+    .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
+  stream: z.boolean().optional().default(false),
+  safe_mode: z.boolean().optional().default(false),
+  random_seed: z.number().int().optional(),
+});
+export type MistralAIChatMessage = z.infer<
+  typeof MistralAIV1ChatCompletionsSchema
+>["messages"][0];
 const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
   anthropic: AnthropicV1CompleteSchema,
   openai: OpenAIV1ChatCompletionSchema,
   "openai-text": OpenAIV1TextCompletionSchema,
   "openai-image": OpenAIV1ImagesGenerationSchema,
-  "google-palm": PalmV1GenerateTextSchema,
+  "google-ai": GoogleAIV1GenerateContentSchema,
+  "mistral-ai": MistralAIV1ChatCompletionsSchema,
 };
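The recurring `.strip()` calls in these schemas are the point of the "strips extraneous properties" commit: unrecognized keys sent by clients get dropped instead of being forwarded upstream. A minimal sketch of the behavior; `logit_bias` here is just an example of an extraneous field:

```ts
import { z } from "zod";

// .strip() makes zod's drop-unknown-keys behavior explicit (and reverses any
// earlier .passthrough()), so extra client fields never reach the upstream API.
const schema = z.object({ model: z.string().max(100) }).strip();

console.log(schema.parse({ model: "mistral-tiny", logit_bias: { 1: 100 } }));
// => { model: "mistral-tiny" }
```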
 /** Transforms an incoming request body to one that matches the target API. */
@@ -174,8 +229,8 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
     return;
   }
-  if (req.inboundApi === "openai" && req.outboundApi === "google-palm") {
-    req.body = openaiToPalm(req);
+  if (req.inboundApi === "openai" && req.outboundApi === "google-ai") {
+    req.body = openaiToGoogleAI(req);
     return;
   }
@@ -310,7 +365,9 @@ function openaiToOpenaiImage(req: Request) {
   return OpenAIV1ImagesGenerationSchema.parse(transformed);
 }
-function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
+function openaiToGoogleAI(
+  req: Request
+): z.infer<typeof GoogleAIV1GenerateContentSchema> {
   const { body } = req;
   const result = OpenAIV1ChatCompletionSchema.safeParse({
     ...body,
@@ -319,40 +376,77 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
   if (!result.success) {
     req.log.warn(
       { issues: result.error.issues, body },
-      "Invalid OpenAI-to-Palm request"
+      "Invalid OpenAI-to-Google AI request"
     );
     throw result.error;
   }
   const { messages, ...rest } = result.data;
-  const prompt = flattenOpenAIChatMessages(messages);
+  const foundNames = new Set<string>();
+  const contents = messages
+    .map((m) => {
+      const role = m.role === "assistant" ? "model" : "user";
+      // Detects character names so we can set stop sequences for them as Gemini
+      // is prone to continuing as the next character.
+      // If names are not available, we'll still try to prefix the message
+      // with generic names so we can set stops for them but they don't work
+      // as well as real names.
+      const text = flattenOpenAIMessageContent(m.content);
+      const propName = m.name?.trim();
+      const textName =
+        m.role === "system" ? "" : text.match(/^(.{0,50}?): /)?.[1]?.trim();
+      const name =
+        propName || textName || (role === "model" ? "Character" : "User");
+      foundNames.add(name);
+      // Prefixing messages with their character name seems to help avoid
+      // Gemini trying to continue as the next character, or at the very least
+      // ensures it will hit the stop sequence. Otherwise it will start a new
+      // paragraph and switch perspectives.
+      // The response will be very likely to include this prefix so frontends
+      // will need to strip it out.
+      const textPrefix = textName ? "" : `${name}: `;
+      return {
+        parts: [{ text: textPrefix + text }],
+        role: m.role === "assistant" ? ("model" as const) : ("user" as const),
+      };
+    })
+    .reduce<GoogleAIChatMessage[]>((acc, msg) => {
+      const last = acc[acc.length - 1];
+      if (last?.role === msg.role) {
+        last.parts[0].text += "\n\n" + msg.parts[0].text;
+      } else {
+        acc.push(msg);
+      }
+      return acc;
+    }, []);
   let stops = rest.stop
     ? Array.isArray(rest.stop)
       ? rest.stop
       : [rest.stop]
     : [];
-  stops.push("\n\nUser:");
-  stops = [...new Set(stops)];
-  z.array(z.string()).max(5).parse(stops);
+  stops.push(...Array.from(foundNames).map((name) => `\n${name}:`));
+  stops = [...new Set(stops)].slice(0, 5);
   return {
-    prompt: { text: prompt },
-    maxOutputTokens: rest.max_tokens,
-    stopSequences: stops,
-    model: "text-bison-001",
-    topP: rest.top_p,
-    temperature: rest.temperature,
+    model: "gemini-pro",
+    stream: rest.stream,
+    contents,
+    tools: [],
+    generationConfig: {
+      maxOutputTokens: rest.max_tokens,
+      stopSequences: stops,
+      topP: rest.top_p,
+      topK: 40, // openai schema doesn't have this, google ai defaults to 40
+      temperature: rest.temperature,
+    },
     safetySettings: [
-      { category: "HARM_CATEGORY_UNSPECIFIED", threshold: "BLOCK_NONE" },
-      { category: "HARM_CATEGORY_DEROGATORY", threshold: "BLOCK_NONE" },
-      { category: "HARM_CATEGORY_TOXICITY", threshold: "BLOCK_NONE" },
-      { category: "HARM_CATEGORY_VIOLENCE", threshold: "BLOCK_NONE" },
-      { category: "HARM_CATEGORY_SEXUAL", threshold: "BLOCK_NONE" },
-      { category: "HARM_CATEGORY_MEDICAL", threshold: "BLOCK_NONE" },
-      { category: "HARM_CATEGORY_DANGEROUS", threshold: "BLOCK_NONE" },
+      { category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" },
+      { category: "HARM_CATEGORY_HATE_SPEECH", threshold: "BLOCK_NONE" },
+      { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold: "BLOCK_NONE" },
+      { category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" },
     ],
   };
 }
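To make the name-detection and stop-sequence logic above concrete, here is a hedged sketch of what the transformation produces for a short chat; the speaker names are invented for illustration:

```ts
// Hypothetical input to openaiToGoogleAI (OpenAI chat format):
const messages = [
  { role: "system", content: "You are Alice." },
  { role: "user", content: "Bob: Hi Alice!" },
  { role: "assistant", content: "Alice: Hello, Bob." },
];

// Approximate output: the nameless system turn gets a generic "User: " prefix
// and merges with the consecutive user turn; detected names become stop
// sequences (deduped, capped at 5).
const expected = {
  model: "gemini-pro",
  contents: [
    { role: "user", parts: [{ text: "User: You are Alice.\n\nBob: Hi Alice!" }] },
    { role: "model", parts: [{ text: "Alice: Hello, Bob." }] },
  ],
  generationConfig: {
    stopSequences: ["\nUser:", "\nBob:", "\nAlice:"],
  },
};
```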
@@ -6,7 +6,8 @@ import { RequestPreprocessor } from "../index";
 const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
 const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
-const BISON_MAX_CONTEXT = 8100;
+const GOOGLE_AI_MAX_CONTEXT = 32000;
+const MISTRAL_AI_MAX_CONTEXT = 32768;
 /**
  * Assigns `req.promptTokens` and `req.outputTokens` based on the request body
@@ -31,9 +32,11 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
     case "anthropic":
       proxyMax = CLAUDE_MAX_CONTEXT;
       break;
-    case "google-palm":
-      proxyMax = BISON_MAX_CONTEXT;
+    case "google-ai":
+      proxyMax = GOOGLE_AI_MAX_CONTEXT;
       break;
+    case "mistral-ai":
+      proxyMax = MISTRAL_AI_MAX_CONTEXT;
+      break;
     case "openai-image":
       return;
     default:
@@ -62,8 +65,10 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
     modelMax = 100000;
   } else if (model.match(/^claude-2/)) {
     modelMax = 200000;
-  } else if (model.match(/^text-bison-\d{3}$/)) {
-    modelMax = BISON_MAX_CONTEXT;
+  } else if (model.match(/^gemini-/)) {
+    modelMax = GOOGLE_AI_MAX_CONTEXT;
+  } else if (model.match(/^mistral-(tiny|small|medium)$/)) {
+    modelMax = MISTRAL_AI_MAX_CONTEXT;
   } else if (model.match(/^anthropic\.claude/)) {
     // Not sure if AWS Claude has the same context limit as Anthropic Claude.
     modelMax = 100000;
@@ -1,8 +1,7 @@
-import express from "express";
 import { pipeline } from "stream";
 import { promisify } from "util";
 import {
-  buildFakeSse,
+  makeCompletionSSE,
   copySseResponseHeaders,
   initializeSseStream,
 } from "../../../shared/streaming";
@@ -59,7 +58,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
   const prefersNativeEvents = req.inboundApi === req.outboundApi;
   const contentType = proxyRes.headers["content-type"];
-  const adapter = new SSEStreamAdapter({ contentType });
+  const adapter = new SSEStreamAdapter({ contentType, api: req.outboundApi });
   const aggregator = new EventAggregator({ format: req.outboundApi });
   const transformer = new SSEMessageTransformer({
     inputFormat: req.outboundApi,
@@ -89,10 +88,20 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
       `Re-enqueueing request due to retryable error during streaming response.`
     );
     req.retryCount++;
-    enqueue(req);
+    await enqueue(req);
   } else {
-    const errorEvent = buildFakeSse("stream-error", err.message, req);
-    res.write(`${errorEvent}data: [DONE]\n\n`);
+    const { message, stack, lastEvent } = err;
+    const eventText = JSON.stringify(lastEvent, null, 2) ?? "undefined";
+    const errorEvent = makeCompletionSSE({
+      format: req.inboundApi,
+      title: "Proxy stream error",
+      message: "An unexpected error occurred while streaming the response.",
+      obj: { message, stack, lastEvent: eventText },
+      reqId: req.id,
+      model: req.body?.model,
+    });
+    res.write(errorEvent);
+    res.write(`data: [DONE]\n\n`);
     res.end();
   }
   throw err;
+112 -33
View File
@@ -152,13 +152,13 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
   };
 };
-function reenqueueRequest(req: Request) {
+async function reenqueueRequest(req: Request) {
   req.log.info(
     { key: req.key?.hash, retryCount: req.retryCount },
     `Re-enqueueing request due to retryable error`
   );
   req.retryCount++;
-  enqueue(req);
+  await enqueue(req);
 }
 /**
@@ -192,7 +192,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
   } else {
     const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
     req.log.warn({ contentEncoding, key: req.key?.hash }, errorMessage);
-    writeErrorResponse(req, res, 500, {
+    writeErrorResponse(req, res, 500, "Internal Server Error", {
       error: errorMessage,
       contentEncoding,
     });
@@ -209,7 +209,9 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
   } catch (error: any) {
     const errorMessage = `Proxy received response with invalid JSON: ${error.message}`;
     req.log.warn({ error: error.stack, key: req.key?.hash }, errorMessage);
-    writeErrorResponse(req, res, 500, { error: errorMessage });
+    writeErrorResponse(req, res, 500, "Internal Server Error", {
+      error: errorMessage,
+    });
     return reject(errorMessage);
   }
 });
@@ -237,6 +239,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
   body
 ) => {
   const statusCode = proxyRes.statusCode || 500;
+  const statusMessage = proxyRes.statusMessage || "Internal Server Error";
   if (statusCode < 400) {
     return;
@@ -253,16 +256,16 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
   } catch (parseError) {
     // Likely Bad Gateway or Gateway Timeout from upstream's reverse proxy
     const hash = req.key?.hash;
-    const statusMessage = proxyRes.statusMessage || "Unknown error";
     req.log.warn({ statusCode, statusMessage, key: hash }, parseError.message);
     const errorObject = {
-      statusCode,
-      statusMessage: proxyRes.statusMessage,
       error: parseError.message,
-      proxy_note: `This is likely a temporary error with the upstream service.`,
+      status: statusCode,
+      statusMessage,
+      proxy_note: `Proxy got back an error, but it was not in JSON format. This is likely a temporary problem with the upstream service.`,
     };
-    writeErrorResponse(req, res, statusCode, errorObject);
+    writeErrorResponse(req, res, statusCode, statusMessage, errorObject);
     throw new HttpError(statusCode, parseError.message);
   }
@@ -288,7 +291,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
   // For Anthropic, this is usually due to missing preamble.
   switch (service) {
     case "openai":
-    case "google-palm":
+    case "google-ai":
+    case "mistral-ai":
     case "azure":
       const filteredCodes = ["content_policy_violation", "content_filter"];
       if (filteredCodes.includes(errorPayload.error?.code)) {
@@ -297,14 +301,14 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
       } else if (errorPayload.error?.code === "billing_hard_limit_reached") {
         // For some reason, some models return this 400 error instead of the
         // same 429 billing error that other models return.
-        handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
+        await handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
       } else {
         errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`;
       }
       break;
     case "anthropic":
     case "aws":
-      maybeHandleMissingPreambleError(req, errorPayload);
+      await maybeHandleMissingPreambleError(req, errorPayload);
       break;
     default:
       assertNever(service);
@@ -314,7 +318,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     keyPool.disable(req.key!, "revoked");
     errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
   } else if (statusCode === 403) {
-    // Amazon is the only service that returns 403.
+    if (service === "anthropic") {
+      keyPool.disable(req.key!, "revoked");
+      errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
+      return;
+    }
     switch (errorType) {
       case "UnrecognizedClientException":
         // Key is invalid.
@@ -335,19 +343,20 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
   } else if (statusCode === 429) {
     switch (service) {
       case "openai":
-        handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
+        await handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
         break;
       case "anthropic":
-        handleAnthropicRateLimitError(req, errorPayload);
+        await handleAnthropicRateLimitError(req, errorPayload);
         break;
       case "aws":
-        handleAwsRateLimitError(req, errorPayload);
+        await handleAwsRateLimitError(req, errorPayload);
         break;
       case "azure":
-        handleAzureRateLimitError(req, errorPayload);
+      case "mistral-ai":
+        await handleAzureRateLimitError(req, errorPayload);
         break;
-      case "google-palm":
-        errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`;
+      case "google-ai":
+        await handleGoogleAIRateLimitError(req, errorPayload);
         break;
       default:
         assertNever(service);
@@ -369,8 +378,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
       case "anthropic":
         errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
         break;
-      case "google-palm":
-        errorPayload.proxy_note = `The requested Google PaLM model might not exist, or the key might not be provisioned for it.`;
+      case "google-ai":
+        errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
+        break;
+      case "mistral-ai":
+        errorPayload.proxy_note = `The requested Mistral AI model might not exist, or the key might not be provisioned for it.`;
         break;
       case "aws":
         errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
@@ -393,7 +405,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     );
   }
-  writeErrorResponse(req, res, statusCode, errorPayload);
+  writeErrorResponse(req, res, statusCode, statusMessage, errorPayload);
   throw new HttpError(statusCode, errorPayload.error?.message);
 };
@@ -416,7 +428,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
  * }
  * ```
  */
-function maybeHandleMissingPreambleError(
+async function maybeHandleMissingPreambleError(
   req: Request,
   errorPayload: ProxiedErrorPayload
 ) {
@@ -429,27 +441,27 @@ function maybeHandleMissingPreambleError(
       "Request failed due to missing preamble. Key will be marked as such for subsequent requests."
     );
     keyPool.update(req.key!, { requiresPreamble: true });
-    reenqueueRequest(req);
+    await reenqueueRequest(req);
     throw new RetryableError("Claude request re-enqueued to add preamble.");
   } else {
     errorPayload.proxy_note = `Proxy received unrecognized error from Anthropic. Check the specific error for more information.`;
   }
 }
-function handleAnthropicRateLimitError(
+async function handleAnthropicRateLimitError(
   req: Request,
   errorPayload: ProxiedErrorPayload
 ) {
   if (errorPayload.error?.type === "rate_limit_error") {
     keyPool.markRateLimited(req.key!);
-    reenqueueRequest(req);
+    await reenqueueRequest(req);
     throw new RetryableError("Claude rate-limited request re-enqueued.");
   } else {
     errorPayload.proxy_note = `Unrecognized rate limit error from Anthropic. Key may be over quota.`;
   }
 }
-function handleAwsRateLimitError(
+async function handleAwsRateLimitError(
   req: Request,
   errorPayload: ProxiedErrorPayload
 ) {
@@ -457,7 +469,7 @@ function handleAwsRateLimitError(
   switch (errorType) {
     case "ThrottlingException":
       keyPool.markRateLimited(req.key!);
-      reenqueueRequest(req);
+      await reenqueueRequest(req);
       throw new RetryableError("AWS rate-limited request re-enqueued.");
     case "ModelNotReadyException":
       errorPayload.proxy_note = `The requested model is overloaded. Try again in a few seconds.`;
@@ -467,11 +479,11 @@ function handleAwsRateLimitError(
   }
 }
-function handleOpenAIRateLimitError(
+async function handleOpenAIRateLimitError(
   req: Request,
   tryAgainMessage: string,
   errorPayload: ProxiedErrorPayload
-): Record<string, any> {
+): Promise<Record<string, any>> {
   const type = errorPayload.error?.type;
   switch (type) {
     case "insufficient_quota":
@@ -500,8 +512,58 @@ function handleOpenAIRateLimitError(
       }
       // Per-minute request or token rate limit is exceeded, which we can retry
-      reenqueueRequest(req);
+      await reenqueueRequest(req);
       throw new RetryableError("Rate-limited request re-enqueued.");
// WIP/nonfunctional
// case "tokens_usage_based":
// // Weird new rate limit type that seems limited to preview models.
// // Distinct from `tokens` type. Can be per-minute or per-day.
//
// // I've seen reports of this error for 500k tokens/day and 10k tokens/min.
// // 10k tokens per minute is problematic, because this is much less than
// // GPT4-Turbo's max context size for a single prompt and is effectively a
// // cap on the max context size for just that key+model, which the app is
// // not able to deal with.
//
// // Similarly if there is a 500k tokens per day limit and 450k tokens have
// // been used today, the max context for that key becomes 50k tokens until
// // the next day and becomes progressively smaller as more tokens are used.
//
// // To work around these keys we will first retry the request a few times.
// // After that we will reject the request, and if it's a per-day limit we
// // will also disable the key.
//
// // "Rate limit reached for gpt-4-1106-preview in organization org-xxxxxxxxxxxxxxxxxxx on tokens_usage_based per day: Limit 500000, Used 460000, Requested 50000"
// // "Rate limit reached for gpt-4-1106-preview in organization org-xxxxxxxxxxxxxxxxxxx on tokens_usage_based per min: Limit 10000, Requested 40000"
//
// const regex =
// /Rate limit reached for .+ in organization .+ on \w+ per (day|min): Limit (\d+)(?:, Used (\d+))?, Requested (\d+)/;
// const [, period, limit, used, requested] =
// errorPayload.error?.message?.match(regex) || [];
//
// req.log.warn(
// { key: req.key?.hash, period, limit, used, requested },
// "Received `tokens_usage_based` rate limit error from OpenAI."
// );
//
// if (!period || !limit || !requested) {
// errorPayload.proxy_note = `Unrecognized rate limit error from OpenAI. (${errorPayload.error?.message})`;
// break;
// }
//
// if (req.retryCount < 2) {
// await reenqueueRequest(req);
// throw new RetryableError("Rate-limited request re-enqueued.");
// }
//
// if (period === "min") {
// errorPayload.proxy_note = `Assigned key can't be used for prompts longer than ${limit} tokens, and no other keys are available right now. Reduce the length of your prompt or try again in a few minutes.`;
// } else {
// errorPayload.proxy_note = `Assigned key has reached its per-day request limit for this model. Try another model.`;
// }
//
// keyPool.markRateLimited(req.key!);
// break;
     default:
       errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
       break;
@@ -509,7 +571,7 @@ function handleOpenAIRateLimitError(
   return errorPayload;
 }
-function handleAzureRateLimitError(
+async function handleAzureRateLimitError(
   req: Request,
   errorPayload: ProxiedErrorPayload
 ) {
@@ -517,7 +579,7 @@ function handleAzureRateLimitError(
   switch (code) {
     case "429":
       keyPool.markRateLimited(req.key!);
-      reenqueueRequest(req);
+      await reenqueueRequest(req);
       throw new RetryableError("Rate-limited request re-enqueued.");
     default:
       errorPayload.proxy_note = `Unrecognized rate limit error from Azure (${code}). Please report this.`;
@@ -525,6 +587,23 @@
   }
 }
//{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}
async function handleGoogleAIRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
const status = errorPayload.error?.status;
switch (status) {
case "RESOURCE_EXHAUSTED":
keyPool.markRateLimited(req.key!);
await reenqueueRequest(req);
throw new RetryableError("Rate-limited request re-enqueued.");
default:
errorPayload.proxy_note = `Unrecognized rate limit error from Google AI (${status}). Please report this.`;
break;
}
}
 const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
   if (isTextGenerationRequest(req) || isImageGenerationRequest(req)) {
     const model = req.body.model;
+8 -4
View File
@@ -9,7 +9,10 @@ import {
 } from "../common";
 import { ProxyResHandlerWithBody } from ".";
 import { assertNever } from "../../../shared/utils";
-import { OpenAIChatMessage } from "../request/preprocessors/transform-outbound-payload";
+import {
+  MistralAIChatMessage,
+  OpenAIChatMessage,
+} from "../request/preprocessors/transform-outbound-payload";
 /** If prompt logging is enabled, enqueues the prompt for logging. */
 export const logPrompt: ProxyResHandlerWithBody = async (
@@ -54,12 +57,13 @@ type OaiImageResult = {
 const getPromptForRequest = (
   req: Request,
   responseBody: Record<string, any>
-): string | OpenAIChatMessage[] | OaiImageResult => {
+): string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult => {
   // Since the prompt logger only runs after the request has been proxied, we
   // can assume the body has already been transformed to the target API's
   // format.
   switch (req.outboundApi) {
     case "openai":
+    case "mistral-ai":
       return req.body.messages;
     case "openai-text":
       return req.body.prompt;
@@ -73,7 +77,7 @@ const getPromptForRequest = (
     };
     case "anthropic":
       return req.body.prompt;
-    case "google-palm":
+    case "google-ai":
       return req.body.prompt.text;
     default:
       assertNever(req.outboundApi);
@@ -81,7 +85,7 @@ const getPromptForRequest = (
 };
 const flattenMessages = (
-  val: string | OpenAIChatMessage[] | OaiImageResult
+  val: string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult
 ): string => {
   if (typeof val === "string") {
     return val.trim();
@@ -4,7 +4,7 @@ import {
   mergeEventsForAnthropic,
   mergeEventsForOpenAIChat,
   mergeEventsForOpenAIText,
-  OpenAIChatCompletionStreamEvent
+  OpenAIChatCompletionStreamEvent,
 } from "./index";
 /**
@@ -27,12 +27,13 @@ export class EventAggregator {
   getFinalResponse() {
     switch (this.format) {
       case "openai":
+      case "google-ai":
+      case "mistral-ai":
         return mergeEventsForOpenAIChat(this.events);
       case "openai-text":
         return mergeEventsForOpenAIText(this.events);
       case "anthropic":
         return mergeEventsForAnthropic(this.events);
-      case "google-palm":
       case "openai-image":
         throw new Error(`SSE aggregation not supported for ${this.format}`);
       default:
@@ -25,6 +25,8 @@ export type StreamingCompletionTransformer = (
 export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai";
 export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";
 export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
+export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
+export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
 export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
 export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
 export { mergeEventsForAnthropic } from "./aggregators/anthropic";
@@ -7,9 +7,10 @@ import {
   anthropicV2ToOpenAI,
   OpenAIChatCompletionStreamEvent,
   openAITextToOpenAIChat,
+  googleAIToOpenAI,
+  passthroughToOpenAI,
   StreamingCompletionTransformer,
 } from "./index";
-import { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
 const genlog = logger.child({ module: "sse-transformer" });
@@ -92,6 +93,7 @@ export class SSEMessageTransformer extends Transform {
       this.push(transformedMessage);
       callback();
     } catch (err) {
+      err.lastEvent = chunk?.toString();
       this.log.error(err, "Error transforming SSE message");
       callback(err);
     }
@@ -104,6 +106,7 @@ function getTransformer(
 ): StreamingCompletionTransformer {
   switch (responseApi) {
     case "openai":
+    case "mistral-ai":
       return passthroughToOpenAI;
     case "openai-text":
       return openAITextToOpenAIChat;
@@ -111,7 +114,8 @@ function getTransformer(
       return version === "2023-01-01"
         ? anthropicV1ToOpenAI
         : anthropicV2ToOpenAI;
-    case "google-palm":
+    case "google-ai":
+      return googleAIToOpenAI;
     case "openai-image":
       throw new Error(`SSE transformation not supported for ${responseApi}`);
     default:
@@ -1,12 +1,20 @@
 import { Transform, TransformOptions } from "stream";
+import { StringDecoder } from "string_decoder";
 // @ts-ignore
 import { Parser } from "lifion-aws-event-stream";
 import { logger } from "../../../../logger";
 import { RetryableError } from "../index";
+import { APIFormat } from "../../../../shared/key-management";
+import StreamArray from "stream-json/streamers/StreamArray";
+import { makeCompletionSSE } from "../../../../shared/streaming";
 const log = logger.child({ module: "sse-stream-adapter" });
-type SSEStreamAdapterOptions = TransformOptions & { contentType?: string };
+type SSEStreamAdapterOptions = TransformOptions & {
+  contentType?: string;
+  api: APIFormat;
+};
 type AwsEventStreamMessage = {
   headers: {
     ":message-type": "event" | "exception";
@@ -21,20 +29,31 @@
  */
 export class SSEStreamAdapter extends Transform {
   private readonly isAwsStream;
-  private parser = new Parser();
+  private readonly isGoogleStream;
+  private awsParser = new Parser();
+  private jsonParser = StreamArray.withParser();
   private partialMessage = "";
+  private decoder = new StringDecoder("utf8");
   constructor(options?: SSEStreamAdapterOptions) {
     super(options);
     this.isAwsStream =
      options?.contentType === "application/vnd.amazon.eventstream";
+    this.isGoogleStream = options?.api === "google-ai";
-    this.parser.on("data", (data: AwsEventStreamMessage) => {
+    this.awsParser.on("data", (data: AwsEventStreamMessage) => {
      const message = this.processAwsEvent(data);
      if (message) {
        this.push(Buffer.from(message + "\n\n"), "utf8");
      }
    });
+    this.jsonParser.on("data", (data: { value: any }) => {
+      const message = this.processGoogleValue(data.value);
+      if (message) {
+        this.push(Buffer.from(message + "\n\n"), "utf8");
+      }
+    });
   }
 protected processAwsEvent(event: AwsEventStreamMessage): string | null {
@@ -53,11 +72,16 @@ export class SSEStreamAdapter extends Transform {
       );
       throw new RetryableError("AWS request throttled mid-stream");
     } else {
-      log.error(
-        { event: eventStr },
-        "Received unexpected AWS event stream message"
-      );
-      return getFakeErrorCompletion("proxy AWS error", eventStr);
+      log.error({ event: eventStr }, "Received bad AWS stream event");
+      return makeCompletionSSE({
+        format: "anthropic",
+        title: "Proxy stream error",
+        message:
+          "The proxy received malformed or unexpected data from AWS while streaming.",
+        obj: event,
+        reqId: "proxy-sse-adapter-message",
+        model: "",
+      });
     }
   } else {
     const { bytes } = payload;
@@ -71,44 +95,62 @@ export class SSEStreamAdapter extends Transform {
     }
   }
// Google doesn't use event streams and just sends elements in an array over
// a long-lived HTTP connection. Needs stream-json to parse the array.
protected processGoogleValue(value: any): string | null {
try {
const candidates = value.candidates ?? [{}];
const hasParts = candidates[0].content?.parts?.length > 0;
if (hasParts) {
return `data: ${JSON.stringify(value)}`;
} else {
log.error({ event: value }, "Received bad Google AI event");
return `data: ${makeCompletionSSE({
format: "google-ai",
title: "Proxy stream error",
message:
"The proxy received malformed or unexpected data from Google AI while streaming.",
obj: value,
reqId: "proxy-sse-adapter-message",
model: "",
})}`;
}
} catch (error) {
error.lastEvent = value;
this.emit("error", error);
return null;
}
}
   _transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
     try {
       if (this.isAwsStream) {
-        this.parser.write(chunk);
+        this.awsParser.write(chunk);
+      } else if (this.isGoogleStream) {
+        this.jsonParser.write(chunk);
       } else {
         // We may receive multiple (or partial) SSE messages in a single chunk,
         // so we need to buffer and emit separate stream events for full
         // messages so we can parse/transform them properly.
-        const str = chunk.toString("utf8");
-        const fullMessages = (this.partialMessage + str).split(/\r?\n\r?\n/);
+        const str = this.decoder.write(chunk);
+        const fullMessages = (this.partialMessage + str).split(
+          /\r\r|\n\n|\r\n\r\n/
+        );
         this.partialMessage = fullMessages.pop() || "";
         for (const message of fullMessages) {
           // Mixing line endings will break some clients and our request queue
           // will have already sent \n for heartbeats, so we need to normalize
           // to \n.
-          this.push(message.replace(/\r\n/g, "\n") + "\n\n");
+          this.push(message.replace(/\r\n?/g, "\n") + "\n\n");
         }
       }
       callback();
     } catch (error) {
+      error.lastEvent = chunk?.toString();
       this.emit("error", error);
       callback(error);
     }
   }
 }
-function getFakeErrorCompletion(type: string, message: string) {
-  const content = `\`\`\`\n[${type}: ${message}]\n\`\`\`\n`;
-  const fakeEvent = JSON.stringify({
-    log_id: "aws-proxy-sse-message",
-    stop_reason: type,
-    completion:
-      "\nProxy encountered an error during streaming response.\n" + content,
-    truncated: false,
-    stop: null,
-    model: "",
-  });
-  return ["event: completion", `data: ${fakeEvent}\n\n`].join("\n");
-}
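Since Google's streaming response is one long JSON array rather than a true event stream, the adapter leans on stream-json's StreamArray, which emits each array element once it has been fully parsed, even when chunks split JSON tokens mid-string. A small self-contained sketch of that behavior; the payload is illustrative:

```ts
import StreamArray from "stream-json/streamers/StreamArray";

const parser = StreamArray.withParser();
parser.on("data", ({ key, value }: { key: number; value: any }) => {
  // Fires once per complete array element, regardless of chunk boundaries.
  console.log(`element ${key}:`, value.candidates?.[0]?.content);
});

// Chunks can split tokens arbitrarily; the parser buffers across writes.
parser.write('[{"candidates":[{"content":{"parts":[{"text":"Hel');
parser.write('lo"}],"role":"model"}}]}');
parser.write("]");
parser.end();
```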
@@ -0,0 +1,76 @@
import { StreamingCompletionTransformer } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";
const log = logger.child({
module: "sse-transformer",
transformer: "google-ai-to-openai",
});
type GoogleAIStreamEvent = {
candidates: {
content: { parts: { text: string }[]; role: string };
finishReason?: "STOP" | "MAX_TOKENS" | "SAFETY" | "RECITATION" | "OTHER";
index: number;
tokenCount?: number;
safetyRatings: { category: string; probability: string }[];
}[];
};
/**
* Transforms an incoming Google AI SSE to an equivalent OpenAI
* chat.completion.chunk SSE.
*/
export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
const { data, index } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data || rawEvent.data === "[DONE]") {
return { position: -1 };
}
const completionEvent = asCompletion(rawEvent);
if (!completionEvent) {
return { position: -1 };
}
const parts = completionEvent.candidates[0].content.parts;
let content = parts[0]?.text ?? "";
// If this is the first chunk, try stripping speaker names from the response
// e.g. "John: Hello" -> "Hello"
if (index === 0) {
content = content.replace(/^(.*?): /, "").trim();
}
const newEvent = {
id: "goo-" + params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
model: params.fallbackModel,
choices: [
{
index: 0,
delta: { content },
finish_reason: completionEvent.candidates[0].finishReason ?? null,
},
],
};
return { position: -1, event: newEvent };
};
function asCompletion(event: ServerSentEvent): GoogleAIStreamEvent | null {
try {
const parsed = JSON.parse(event.data) as GoogleAIStreamEvent;
if (parsed.candidates?.length > 0) {
return parsed;
} else {
// noinspection ExceptionCaughtLocallyJS
throw new Error("Missing required fields");
}
} catch (error) {
log.warn({ error: error.stack, event }, "Received invalid event");
}
return null;
}
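A quick illustration of the transformer above with a hypothetical stream element; `fallbackId` and `fallbackModel` would normally come from `SSEMessageTransformer`, and the cast is only to keep the sketch short:

```ts
const result = googleAIToOpenAI({
  data: 'data: {"candidates":[{"content":{"parts":[{"text":"Alice: Hi"}],"role":"model"},"index":0,"safetyRatings":[]}]}',
  index: 0,
  fallbackId: "abc123",
  fallbackModel: "gemini-pro",
} as any);

// result.event comes back roughly as:
// {
//   id: "goo-abc123",
//   object: "chat.completion.chunk",
//   model: "gemini-pro",
//   choices: [{ index: 0, delta: { content: "Hi" }, finish_reason: null }],
// }
// The "Alice: " speaker prefix is stripped because index === 0.
```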
+116
View File
@@ -0,0 +1,116 @@
import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import {
getMistralAIModelFamily,
MistralAIModelFamily,
ModelFamily,
} from "../shared/models";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeBody,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
// https://docs.mistral.ai/platform/endpoints
export const KNOWN_MISTRAL_AI_MODELS = [
"mistral-tiny",
"mistral-small",
"mistral-medium",
];
let modelsCache: any = null;
let modelsCacheTime = 0;
export function generateModelList(models = KNOWN_MISTRAL_AI_MODELS) {
let available = new Set<MistralAIModelFamily>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "mistral-ai") continue;
key.modelFamilies.forEach((family) =>
available.add(family as MistralAIModelFamily)
);
}
const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
available = new Set([...available].filter((x) => allowed.has(x)));
return models
.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "mistral-ai",
}))
.filter((model) => available.has(getMistralAIModelFamily(model.id)));
}
const handleModelRequest: RequestHandler = (_req, res) => {
  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
    res.status(200).json(modelsCache);
    return;
  }
  const result = generateModelList();
  modelsCache = { object: "list", data: result };
  modelsCacheTime = new Date().getTime();
  res.status(200).json(modelsCache);
};
const mistralAIResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
res.status(200).json(body);
};
const mistralAIProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://api.mistral.ai",
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [addKey, finalizeBody],
}),
proxyRes: createOnProxyResHandler([mistralAIResponseHandler]),
error: handleProxyError,
},
}),
});
const mistralAIRouter = Router();
mistralAIRouter.get("/v1/models", handleModelRequest);
// General chat completion endpoint.
mistralAIRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware({
inApi: "mistral-ai",
outApi: "mistral-ai",
service: "mistral-ai",
}),
mistralAIProxy
);
export const mistralAI = mistralAIRouter;
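For a quick smoke test against a running proxy, assuming the default mounting shown in the routes diff below, a local deployment, and a valid user token (the port and token are assumptions; both depend on your deployment and gatekeeper config):

// Hypothetical client call; runs in an ESM context with top-level await.
const res = await fetch(
  "http://localhost:7860/proxy/mistral-ai/v1/chat/completions",
  {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: "Bearer <user-token>",
    },
    body: JSON.stringify({
      model: "mistral-medium",
      messages: [{ role: "user", content: "Hello!" }],
    }),
  }
);
console.log(await res.json());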
-1
View File
@@ -17,7 +17,6 @@ import {
 } from "./middleware/response";
 import { generateModelList } from "./openai";
 import {
-  mirrorGeneratedImage,
   OpenAIImageGenerationResult,
 } from "../shared/file-storage/mirror-generated-image";
-170
View File
@@ -1,170 +0,0 @@
import { Request, RequestHandler, Router } from "express";
import * as http from "http";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeBody,
forceModel,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
let modelsCache: any = null;
let modelsCacheTime = 0;
const getModelsResponse = () => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
if (!config.googlePalmKey) return { object: "list", data: [] };
const bisonVariants = ["text-bison-001"];
const models = bisonVariants.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "google",
permission: [],
root: "palm",
parent: null,
}));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
return modelsCache;
};
const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
};
/** Only used for non-streaming requests. */
const palmResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (req.inboundApi === "openai") {
req.log.info("Transforming Google PaLM response to OpenAI format");
body = transformPalmResponse(body, req);
}
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
// TODO: PaLM has no streaming capability which will pose a problem here if
// requests wait in the queue for too long. Probably need to fake streaming
// and return the entire completion in one stream event using the other
// response handler.
res.status(200).json(body);
};
/**
 * Transforms a model response from the Google PaLM API to match those from the
 * OpenAI API, for users using PaLM via the OpenAI-compatible endpoint. Only
 * non-streaming requests are handled here, as the PaLM API does not support
 * streaming at all.
 */
function transformPalmResponse(
palmRespBody: Record<string, any>,
req: Request
): Record<string, any> {
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
return {
id: "plm-" + v4(),
object: "chat.completion",
created: Date.now(),
model: req.body.model,
usage: {
prompt_tokens: req.promptTokens,
completion_tokens: req.outputTokens,
total_tokens: totalTokens,
},
choices: [
{
message: {
role: "assistant",
content: palmRespBody.candidates[0].output,
},
finish_reason: null, // palm doesn't return this
index: 0,
},
],
};
}
function reassignPathForPalmModel(proxyReq: http.ClientRequest, req: Request) {
if (req.body.stream) {
throw new Error("Google PaLM API doesn't support streaming requests");
}
// PaLM API specifies the model in the URL path, not the request body. This
// doesn't work well with our rewriter architecture, so we need to manually
// fix it here.
// POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateText
// POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateMessage
// The chat api (generateMessage) is not very useful at this time as it has
// few params and no adjustable safety settings.
proxyReq.path = proxyReq.path.replace(
/^\/v1\/chat\/completions/,
`/v1beta2/models/${req.body.model}:generateText`
);
}
const googlePalmProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://generativelanguage.googleapis.com",
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [reassignPathForPalmModel, addKey, finalizeBody],
}),
proxyRes: createOnProxyResHandler([palmResponseHandler]),
error: handleProxyError,
},
}),
});
const palmRouter = Router();
palmRouter.get("/v1/models", handleModelRequest);
// OpenAI-to-Google PaLM compatibility endpoint.
palmRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "openai", outApi: "google-palm", service: "google-palm" },
{ afterTransform: [forceModel("text-bison-001")] }
),
googlePalmProxy
);
export const googlePalm = palmRouter;
+56 -31
View File
@@ -14,8 +14,12 @@
 import crypto from "crypto";
 import type { Handler, Request } from "express";
 import { keyPool } from "../shared/key-management";
-import { getModelFamilyForRequest, MODEL_FAMILIES, ModelFamily } from "../shared/models";
-import { buildFakeSse, initializeSseStream } from "../shared/streaming";
+import {
+  getModelFamilyForRequest,
+  MODEL_FAMILIES,
+  ModelFamily,
+} from "../shared/models";
+import { makeCompletionSSE, initializeSseStream } from "../shared/streaming";
 import { logger } from "../logger";
 import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
 import { RequestPreprocessor } from "./middleware/request";
@@ -37,6 +41,7 @@ const LOAD_THRESHOLD = parseFloat(process.env.LOAD_THRESHOLD ?? "50");
 const PAYLOAD_SCALE_FACTOR = parseFloat(
   process.env.PAYLOAD_SCALE_FACTOR ?? "6"
 );
+const QUEUE_JOIN_TIMEOUT = 5000;
 
 /**
  * Returns an identifier for a request. This is used to determine if a
@@ -60,7 +65,7 @@ const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
 const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);
 
-export function enqueue(req: Request) {
+export async function enqueue(req: Request) {
   const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
   let isGuest = req.user?.token === undefined;
@@ -92,7 +97,7 @@ export function enqueue(req: Request) {
   if (stream === "true" || stream === true || req.isStreaming) {
     const res = req.res!;
     if (!res.headersSent) {
-      initStreaming(req);
+      await initStreaming(req);
     }
     registerHeartbeat(req);
   } else if (getProxyLoad() > LOAD_THRESHOLD) {
@@ -119,7 +124,9 @@ export function enqueue(req: Request) {
   if (req.retryCount ?? 0 > 0) {
     req.log.info({ retries: req.retryCount }, `Enqueued request for retry.`);
   } else {
-    req.log.info(`Enqueued new request.`);
+    const size = req.socket.bytesRead;
+    const endpoint = req.url?.split("?")[0];
+    req.log.info({ size, endpoint }, `Enqueued new request.`);
   }
 }
@@ -189,10 +196,10 @@ function processQueue() {
   reqs.filter(Boolean).forEach((req) => {
     if (req?.proceed) {
       const modelFamily = getModelFamilyForRequest(req!);
-      req.log.info({
-        retries: req.retryCount,
-        partition: modelFamily,
-      }, `Dequeuing request.`);
+      req.log.info(
+        { retries: req.retryCount, partition: modelFamily },
+        `Dequeuing request.`
+      );
       req.proceed();
     }
   });
@@ -327,7 +334,7 @@ export function createQueueMiddleware({
   beforeProxy?: RequestPreprocessor;
   proxyMiddleware: Handler;
 }): Handler {
-  return (req, res, next) => {
+  return async (req, res, next) => {
     req.proceed = async () => {
       if (beforeProxy) {
         try {
@@ -345,7 +352,7 @@
     };
 
     try {
-      enqueue(req);
+      await enqueue(req);
     } catch (err: any) {
      req.res!.status(429).json({
        type: "proxy_error",
@@ -367,8 +374,15 @@ function killQueuedRequest(req: Request) {
   try {
     const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes.`;
     if (res.headersSent) {
-      const fakeErrorEvent = buildFakeSse("proxy queue error", message, req);
-      res.write(fakeErrorEvent);
+      const event = makeCompletionSSE({
+        format: req.inboundApi,
+        title: "Proxy queue error",
+        message,
+        reqId: String(req.id),
+        model: req.body?.model,
+      });
+      res.write(event);
+      res.write(`data: [DONE]\n\n`);
       res.end();
     } else {
       res.status(500).json({ error: message });
@@ -378,20 +392,39 @@ function killQueuedRequest(req: Request) {
   }
 }
 
-function initStreaming(req: Request) {
+async function initStreaming(req: Request) {
   const res = req.res!;
   initializeSseStream(res);
-  if (req.query.badSseParser) {
-    // Some clients have a broken SSE parser that doesn't handle comments
-    // correctly. These clients can pass ?badSseParser=true to
-    // disable comments in the SSE stream.
-    res.write(getHeartbeatPayload());
-    return;
-  }
-  res.write(`: joining queue at position ${queue.length}\n\n`);
-  res.write(getHeartbeatPayload());
+  const joinMsg = `: joining queue at position ${
+    queue.length
+  }\n\n${getHeartbeatPayload()}`;
+
+  let drainTimeout: NodeJS.Timeout;
+  const welcome = new Promise<void>((resolve, reject) => {
+    const onDrain = () => {
+      clearTimeout(drainTimeout);
+      req.log.debug(`Client finished consuming join message.`);
+      res.off("drain", onDrain);
+      resolve();
+    };
+    drainTimeout = setTimeout(() => {
+      res.off("drain", onDrain);
+      res.destroy();
+      reject(new Error("Unresponsive streaming client; killing connection"));
+    }, QUEUE_JOIN_TIMEOUT);
+    if (!res.write(joinMsg)) {
+      req.log.warn("Kernel buffer is full; holding client request.");
+      res.once("drain", onDrain);
+    } else {
+      clearTimeout(drainTimeout);
+      resolve();
+    }
+  });
+  await welcome;
 }
 
 /**
@@ -451,14 +484,6 @@ function removeProxyMiddlewareEventListeners(req: Request) {
 export function registerHeartbeat(req: Request) {
   const res = req.res!;
-  const currentSize = getHeartbeatSize();
-  req.log.debug({
-    currentSize,
-    HEARTBEAT_INTERVAL,
-    PAYLOAD_SCALE_FACTOR,
-    MAX_HEARTBEAT_SIZE,
-  }, "Joining queue with heartbeat.");
   let isBufferFull = false;
   let bufferFullCount = 0;
   req.heartbeatInterval = setInterval(() => {
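The new initStreaming logic leans on Node's stream backpressure contract: write() returns false once the outgoing buffer is full, and "drain" fires when the client has consumed it. A generic sketch of the same write-or-kick pattern, independent of this codebase:

import { Writable } from "stream";

// Write `data`, pausing on backpressure; if the peer never drains within
// `timeoutMs`, destroy the stream rather than buffering forever.
function writeWithTimeout(out: Writable, data: string, timeoutMs: number) {
  return new Promise<void>((resolve, reject) => {
    if (out.write(data)) return resolve(); // fit in the buffer; done
    const timer = setTimeout(() => {
      out.off("drain", onDrain);
      out.destroy();
      reject(new Error("client too slow to consume stream"));
    }, timeoutMs);
    const onDrain = () => {
      clearTimeout(timer);
      resolve();
    };
    out.once("drain", onDrain);
  });
}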
+6 -4
View File
@@ -4,7 +4,8 @@ import { checkRisuToken } from "./check-risu-token";
 import { openai } from "./openai";
 import { openaiImage } from "./openai-image";
 import { anthropic } from "./anthropic";
-import { googlePalm } from "./palm";
+import { googleAI } from "./google-ai";
+import { mistralAI } from "./mistral-ai";
 import { aws } from "./aws";
 import { azure } from "./azure";
@@ -18,8 +19,8 @@ proxyRouter.use((req, _res, next) => {
   next();
 });
 proxyRouter.use(
-  express.json({ limit: "10mb" }),
-  express.urlencoded({ extended: true, limit: "10mb" })
+  express.json({ limit: "1mb" }),
+  express.urlencoded({ extended: true, limit: "1mb" })
 );
 proxyRouter.use(gatekeeper);
 proxyRouter.use(checkRisuToken);
@@ -31,7 +32,8 @@ proxyRouter.use((req, _res, next) => {
 proxyRouter.use("/openai", addV1, openai);
 proxyRouter.use("/openai-image", addV1, openaiImage);
 proxyRouter.use("/anthropic", addV1, anthropic);
-proxyRouter.use("/google-palm", addV1, googlePalm);
+proxyRouter.use("/google-ai", addV1, googleAI);
+proxyRouter.use("/mistral-ai", addV1, mistralAI);
 proxyRouter.use("/aws/claude", addV1, aws);
 proxyRouter.use("/azure/openai", addV1, azure);
 // Redirect browser requests to the homepage.
+9 -3
View File
@@ -12,7 +12,8 @@ import { setupAssetsDir } from "./shared/file-storage/setup-assets-dir";
 import { keyPool } from "./shared/key-management";
 import { adminRouter } from "./admin/routes";
 import { proxyRouter } from "./proxy/routes";
-import { handleInfoPage } from "./info-page";
+import { handleInfoPage, renderPage } from "./info-page";
+import { buildInfo } from "./service-info";
 import { logQueue } from "./shared/prompt-logging";
 import { start as startRequestQueue } from "./proxy/queue";
 import { init as initUserStore } from "./shared/users/user-store";
@@ -53,6 +54,10 @@ app.use(
 // a load balancer/reverse proxy, which is necessary to determine request IP
 // addresses correctly.
 app.set("trust proxy", true);
+app.use((req, _res, next) => {
+  req.log.info({ ip: req.ip, forwardedFor: req.get("x-forwarded-for") });
+  next();
+});
 app.set("view engine", "ejs");
 app.set("views", [
@@ -67,13 +72,14 @@ app.get("/health", (_req, res) => res.sendStatus(200));
 app.use(cors());
 app.use(checkOrigin);
 
-// routes
 app.get("/", handleInfoPage);
+app.get("/status", (req, res) => {
+  res.json(buildInfo(req.protocol + "://" + req.get("host"), false));
+});
 app.use("/admin", adminRouter);
 app.use("/proxy", proxyRouter);
 app.use("/user", userRouter);
 
-// 500 and 404
 app.use((err: any, _req: unknown, res: express.Response, _next: unknown) => {
   if (err.status) {
     res.status(err.status).json({ error: err.message });
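The new /status route serves the same ServiceInfo object that backs the info page (see service-info.ts below). A trimmed, hypothetical response for illustration; the values are made up and actual fields depend on which services have keys configured:

const exampleStatus = {
  uptime: 3600,
  endpoints: {
    openai: "https://host/proxy/openai",
    "mistral-ai": "https://host/proxy/mistral-ai",
  },
  proompts: 1234,
  tookens: "1.2m ($4.56)",
  status: "Checking 2 keys...",
  build: "dev",
};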
+441
View File
@@ -0,0 +1,441 @@
/** Calculates and returns stats about the service. */
import { config, listConfig } from "./config";
import {
AnthropicKey,
AwsBedrockKey,
AzureOpenAIKey,
GoogleAIKey,
keyPool,
OpenAIKey,
} from "./shared/key-management";
import {
AnthropicModelFamily,
assertIsKnownModelFamily,
AwsBedrockModelFamily,
AzureOpenAIModelFamily,
GoogleAIModelFamily,
LLM_SERVICES,
LLMService,
MistralAIModelFamily,
MODEL_FAMILY_SERVICE,
ModelFamily,
OpenAIModelFamily,
} from "./shared/models";
import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats";
import { getUniqueIps } from "./proxy/rate-limit";
import { assertNever } from "./shared/utils";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
import { MistralAIKey } from "./shared/key-management/mistral-ai/provider";
const CACHE_TTL = 2000;
type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
k.service === "openai";
const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
k.service === "azure";
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
k.service === "anthropic";
const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
k.service === "google-ai";
const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
k.service === "mistral-ai";
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
/** Stats aggregated across all keys for a given service. */
type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs";
/** Stats aggregated across all keys for a given model family. */
type ModelAggregates = {
active: number;
trial?: number;
revoked?: number;
overQuota?: number;
pozzed?: number;
awsLogged?: number;
queued: number;
queueTime: string;
tokens: number;
};
/** All possible combinations of model family and aggregate type. */
type ModelAggregateKey = `${ModelFamily}__${keyof ModelAggregates}`;
type AllStats = {
proompts: number;
tokens: number;
tokenCost: number;
} & { [modelFamily in ModelFamily]?: ModelAggregates } & {
[service in LLMService as `${service}__${ServiceAggregate}`]?: number;
};
type BaseFamilyInfo = {
usage?: string;
activeKeys: number;
revokedKeys?: number;
proomptersInQueue?: number;
estimatedQueueTime?: string;
};
type OpenAIInfo = BaseFamilyInfo & {
trialKeys?: number;
overQuotaKeys?: number;
};
type AnthropicInfo = BaseFamilyInfo & { pozzedKeys?: number };
type AwsInfo = BaseFamilyInfo & { privacy?: string };
// prettier-ignore
export type ServiceInfo = {
uptime: number;
endpoints: {
openai?: string;
openai2?: string;
"openai-image"?: string;
anthropic?: string;
"google-ai"?: string;
"mistral-ai"?: string;
aws?: string;
azure?: string;
};
proompts?: number;
tookens?: string;
proomptersNow?: number;
status?: string;
config: ReturnType<typeof listConfig>;
build: string;
} & { [f in OpenAIModelFamily]?: OpenAIInfo }
& { [f in AnthropicModelFamily]?: AnthropicInfo; }
& { [f in AwsBedrockModelFamily]?: AwsInfo }
& { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
& { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
& { [f in MistralAIModelFamily]?: BaseFamilyInfo };
// https://stackoverflow.com/a/66661477
// type DeepKeyOf<T> = (
// [T] extends [never]
// ? ""
// : T extends object
// ? {
// [K in Exclude<keyof T, symbol>]: `${K}${DotPrefix<DeepKeyOf<T[K]>>}`;
// }[Exclude<keyof T, symbol>]
// : ""
// ) extends infer D
// ? Extract<D, string>
// : never;
// type DotPrefix<T extends string> = T extends "" ? "" : `.${T}`;
// type ServiceInfoPath = `{${DeepKeyOf<ServiceInfo>}}`;
const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
openai: {
openai: `%BASE%/openai`,
openai2: `%BASE%/openai/turbo-instruct`,
"openai-image": `%BASE%/openai-image`,
},
anthropic: {
anthropic: `%BASE%/anthropic`,
},
"google-ai": {
"google-ai": `%BASE%/google-ai`,
},
"mistral-ai": {
"mistral-ai": `%BASE%/mistral-ai`,
},
aws: {
aws: `%BASE%/aws/claude`,
},
azure: {
azure: `%BASE%/azure/openai`,
},
};
const modelStats = new Map<ModelAggregateKey, number>();
const serviceStats = new Map<keyof AllStats, number>();
let cachedInfo: ServiceInfo | undefined;
let cacheTime = 0;
export function buildInfo(baseUrl: string, forAdmin = false): ServiceInfo {
if (cacheTime + CACHE_TTL > Date.now()) return cachedInfo!;
const keys = keyPool.list();
const accessibleFamilies = new Set(
keys
.flatMap((k) => k.modelFamilies)
.filter((f) => config.allowedModelFamilies.includes(f))
.concat("turbo")
);
modelStats.clear();
serviceStats.clear();
keys.forEach(addKeyToAggregates);
const endpoints = getEndpoints(baseUrl, accessibleFamilies);
const trafficStats = getTrafficStats();
const { serviceInfo, modelFamilyInfo } =
getServiceModelStats(accessibleFamilies);
const status = getStatus();
if (config.staticServiceInfo && !forAdmin) {
delete trafficStats.proompts;
delete trafficStats.tookens;
delete trafficStats.proomptersNow;
for (const family of Object.keys(modelFamilyInfo)) {
assertIsKnownModelFamily(family);
delete modelFamilyInfo[family]?.proomptersInQueue;
delete modelFamilyInfo[family]?.estimatedQueueTime;
delete modelFamilyInfo[family]?.usage;
}
}
return (cachedInfo = {
uptime: Math.floor(process.uptime()),
endpoints,
...trafficStats,
...serviceInfo,
status,
...modelFamilyInfo,
config: listConfig(),
build: process.env.BUILD_INFO || "dev",
});
}
function getStatus() {
if (!config.checkKeys) return "Key checking is disabled.";
let unchecked = 0;
for (const service of LLM_SERVICES) {
unchecked += serviceStats.get(`${service}__uncheckedKeys`) || 0;
}
return unchecked ? `Checking ${unchecked} keys...` : undefined;
}
function getEndpoints(baseUrl: string, accessibleFamilies: Set<ModelFamily>) {
const endpoints: Record<string, string> = {};
for (const service of LLM_SERVICES) {
for (const [name, url] of Object.entries(SERVICE_ENDPOINTS[service])) {
endpoints[name] = url.replace("%BASE%", baseUrl);
}
if (service === "openai" && !accessibleFamilies.has("dall-e")) {
delete endpoints["openai-image"];
}
}
return endpoints;
}
type TrafficStats = Pick<ServiceInfo, "proompts" | "tookens" | "proomptersNow">;
function getTrafficStats(): TrafficStats {
const tokens = serviceStats.get("tokens") || 0;
const tokenCost = serviceStats.get("tokenCost") || 0;
return {
proompts: serviceStats.get("proompts") || 0,
tookens: `${prettyTokens(tokens)}${getCostSuffix(tokenCost)}`,
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
};
}
function getServiceModelStats(accessibleFamilies: Set<ModelFamily>) {
const serviceInfo: {
[s in LLMService as `${s}${"Keys" | "Orgs"}`]?: number;
} = {};
const modelFamilyInfo: { [f in ModelFamily]?: BaseFamilyInfo } = {};
for (const service of LLM_SERVICES) {
const hasKeys = serviceStats.get(`${service}__keys`) || 0;
if (!hasKeys) continue;
serviceInfo[`${service}Keys`] = hasKeys;
accessibleFamilies.forEach((f) => {
if (MODEL_FAMILY_SERVICE[f] === service) {
modelFamilyInfo[f] = getInfoForFamily(f);
}
});
if (service === "openai" && config.checkKeys) {
serviceInfo.openaiOrgs = getUniqueOpenAIOrgs(keyPool.list());
}
}
return { serviceInfo, modelFamilyInfo };
}
function getUniqueOpenAIOrgs(keys: KeyPoolKey[]) {
const orgIds = new Set(
keys.filter((k) => k.service === "openai").map((k: any) => k.organizationId)
);
return orgIds.size;
}
function increment<T extends keyof AllStats | ModelAggregateKey>(
map: Map<T, number>,
key: T,
delta = 1
) {
map.set(key, (map.get(key) || 0) + delta);
}
function addKeyToAggregates(k: KeyPoolKey) {
increment(serviceStats, "proompts", k.promptCount);
increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
increment(serviceStats, "mistral-ai__keys", k.service === "mistral-ai" ? 1 : 0);
increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
let sumTokens = 0;
let sumCost = 0;
switch (k.service) {
case "openai":
if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
increment(
serviceStats,
"openai__uncheckedKeys",
Boolean(k.lastChecked) ? 0 : 1
);
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
});
break;
case "azure":
if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
});
break;
case "anthropic": {
if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
const family = "claude";
sumTokens += k.claudeTokens;
sumCost += getTokenCostUsd(family, k.claudeTokens);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k.claudeTokens);
increment(modelStats, `${family}__pozzed`, k.isPozzed ? 1 : 0);
increment(
serviceStats,
"anthropic__uncheckedKeys",
Boolean(k.lastChecked) ? 0 : 1
);
break;
}
case "google-ai": {
if (!keyIsGoogleAIKey(k)) throw new Error("Invalid key type");
const family = "gemini-pro";
sumTokens += k["gemini-proTokens"];
sumCost += getTokenCostUsd(family, k["gemini-proTokens"]);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]);
break;
}
case "mistral-ai": {
if (!keyIsMistralAIKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
});
break;
}
case "aws": {
if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
const family = "aws-claude";
sumTokens += k["aws-claudeTokens"];
sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]);
// Ignore revoked keys for aws logging stats, but include keys where the
// logging status is unknown.
const countAsLogged =
k.lastChecked && !k.isDisabled && k.awsLoggingStatus !== "disabled";
increment(modelStats, `${family}__awsLogged`, countAsLogged ? 1 : 0);
break;
}
default:
assertNever(k.service);
}
increment(serviceStats, "tokens", sumTokens);
increment(serviceStats, "tokenCost", sumCost);
}
function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
const tokens = modelStats.get(`${family}__tokens`) || 0;
const cost = getTokenCostUsd(family, tokens);
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo = {
usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
activeKeys: modelStats.get(`${family}__active`) || 0,
revokedKeys: modelStats.get(`${family}__revoked`) || 0,
};
// Add service-specific stats to the info object.
if (config.checkKeys) {
const service = MODEL_FAMILY_SERVICE[family];
switch (service) {
case "openai":
info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0;
info.trialKeys = modelStats.get(`${family}__trial`) || 0;
// Delete trial/revoked keys for non-turbo families.
// Trials are turbo 99% of the time, and if a key is invalid we don't
// know what models it might have had assigned to it.
if (family !== "turbo") {
delete info.trialKeys;
delete info.revokedKeys;
}
break;
case "anthropic":
info.pozzedKeys = modelStats.get(`${family}__pozzed`) || 0;
break;
case "aws":
const logged = modelStats.get(`${family}__awsLogged`) || 0;
if (logged > 0) {
info.privacy = config.allowAwsLogging
? `${logged} active keys are potentially logged.`
: `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`;
}
break;
}
}
// Add queue stats to the info object.
const queue = getQueueInformation(family);
info.proomptersInQueue = queue.proomptersInQueue;
info.estimatedQueueTime = queue.estimatedQueueTime;
return info;
}
/** Returns queue time in seconds, or minutes + seconds if over 60 seconds. */
function getQueueInformation(partition: ModelFamily) {
const waitMs = getEstimatedWaitTime(partition);
  const waitTime =
    waitMs < 60000
      ? `${Math.round(waitMs / 1000)}sec`
      : `${Math.floor(waitMs / 60000)}min, ${Math.round(
          (waitMs % 60000) / 1000
        )}sec`;
return {
proomptersInQueue: getQueueLength(partition),
estimatedQueueTime: waitMs > 2000 ? waitTime : "no wait",
};
}
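A few worked values for the formatter above, assuming the minutes are floored and the leftover seconds rounded:

// waitMs =  1_400 -> estimatedQueueTime: "no wait"   (under the 2-second threshold)
// waitMs =  9_600 -> "10sec"                         (9.6s rounds up)
// waitMs = 95_000 -> "1min, 35sec"                   (1 full minute + 35s remainder)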
+5 -3
View File
@@ -1,8 +1,10 @@
+// noinspection JSUnusedGlobalSymbols,ES6UnusedImports
 import type { HttpRequest } from "@smithy/types";
 import { Express } from "express-serve-static-core";
-import { APIFormat, Key, LLMService } from "../shared/key-management";
-import { User } from "../shared/users/schema";
-import { ModelFamily } from "../shared/models";
+import { APIFormat, Key } from "./key-management";
+import { User } from "./users/schema";
+import { LLMService, ModelFamily } from "./models";
 
 declare global {
   namespace Express {
@@ -48,20 +48,20 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
   protected handleAxiosError(key: AnthropicKey, error: AxiosError) {
     if (error.response && AnthropicKeyChecker.errorIsAnthropicAPIError(error)) {
       const { status, data } = error.response;
-      if (status === 401) {
+      if (status === 401 || status === 403) {
         this.log.warn(
           { key: key.hash, error: data },
           "Key is invalid or revoked. Disabling key."
         );
         this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
-      } else if (status === 429) {
+      }
+      else if (status === 429) {
         switch (data.error.type) {
           case "rate_limit_error":
             this.log.warn(
               { key: key.hash, error: error.message },
               "Key is rate limited. Rechecking in 10 seconds."
             );
-            0;
             const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
             this.updateKey(key.hash, { lastChecked: next });
             break;
+41 -26
View File
@@ -36,34 +36,10 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
   protected async testKeyOrFail(key: AzureOpenAIKey) {
     const model = await this.testModel(key);
-    this.log.info(
-      { key: key.hash, deploymentModel: model },
-      "Checked key."
-    );
+    this.log.info({ key: key.hash, deploymentModel: model }, "Checked key.");
     this.updateKey(key.hash, { modelFamilies: [model] });
   }
 
-  // provided api-key header isn't valid (401)
-  // {
-  //   "error": {
-  //     "code": "401",
-  //     "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
-  //   }
-  // }
-  // api key correct but deployment id is wrong (404)
-  // {
-  //   "error": {
-  //     "code": "DeploymentNotFound",
-  //     "message": "The API deployment for this resource does not exist. If you created the deployment within the last 5 minutes, please wait a moment and try again."
-  //   }
-  // }
-  // resource name is wrong (node will throw ENOTFOUND)
-  // rate limited (429)
-  // TODO: try to reproduce this
   protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) {
     if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) {
       const data = error.response.data;
@@ -88,6 +64,20 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
           isDisabled: true,
           isRevoked: true,
         });
+      case "429":
+        this.log.warn(
+          { key: key.hash, errorType, error: error.response.data },
+          "Key is rate limited. Rechecking key in 1 minute."
+        );
+        this.updateKey(key.hash, { lastChecked: Date.now() });
+        setTimeout(async () => {
+          this.log.info(
+            { key: key.hash },
+            "Rechecking Azure key after rate limit."
+          );
+          await this.checkKey(key);
+        }, 1000 * 60);
+        return;
       default:
         this.log.error(
           { key: key.hash, errorType, error: error.response.data, status },
@@ -129,7 +119,32 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
       headers: { "Content-Type": "application/json", "api-key": apiKey },
     });
-    return getAzureOpenAIModelFamily(data.model);
+    const family = getAzureOpenAIModelFamily(data.model);
+
+    // Azure returns "gpt-4" even for GPT-4 Turbo, so we need further checks.
+    // Otherwise we can use the model family Azure returned.
+    if (family !== "azure-gpt4") {
+      return family;
+    }
+
+    // Try to send an oversized prompt. GPT-4 Turbo can handle this but regular
+    // GPT-4 will return a Bad Request error.
+    const contextText = {
+      max_tokens: 9000,
+      stream: false,
+      temperature: 0,
+      seed: 0,
+      messages: [{ role: "user", content: "" }],
+    };
+    const { data: contextTest, status } = await axios.post(url, contextText, {
+      headers: { "Content-Type": "application/json", "api-key": apiKey },
+      validateStatus: (status) => status === 400 || status === 200,
+    });
+    const code = contextTest.error?.code;
+    this.log.debug({ code, status }, "Performed Azure GPT4 context size test.");
+    if (code === "context_length_exceeded") return "azure-gpt4";
+    return "azure-gpt4-turbo";
   }
 
   static errorIsAzureError(error: AxiosError): error is AxiosError<AzureError> {
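The probe works because an 8k GPT-4 deployment rejects max_tokens: 9000 with a 400/context_length_exceeded error, while a 128k GPT-4 Turbo deployment accepts it. A standalone sketch of the same check; the URL and key handling are simplified placeholders:

import axios from "axios";

// Returns true if the deployment behind `url` looks like GPT-4 Turbo.
async function looksLikeGpt4Turbo(url: string, apiKey: string) {
  const probe = {
    max_tokens: 9000,
    stream: false,
    temperature: 0,
    messages: [{ role: "user", content: "" }],
  };
  const { data, status } = await axios.post(url, probe, {
    headers: { "Content-Type": "application/json", "api-key": apiKey },
    validateStatus: (s) => s === 400 || s === 200,
  });
  // Plain GPT-4 rejects the oversized max_tokens outright.
  if (status === 400 && data.error?.code === "context_length_exceeded") {
    return false;
  }
  return true;
}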
@@ -6,7 +6,6 @@ import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models"; import { getAzureOpenAIModelFamily } from "../../models";
import { OpenAIModel } from "../openai/provider"; import { OpenAIModel } from "../openai/provider";
import { AzureOpenAIKeyChecker } from "./checker"; import { AzureOpenAIKeyChecker } from "./checker";
import { AwsKeyChecker } from "../aws/checker";
export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">; export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;
@@ -2,13 +2,17 @@ import crypto from "crypto";
 import { Key, KeyProvider } from "..";
 import { config } from "../../../config";
 import { logger } from "../../../logger";
-import type { GooglePalmModelFamily } from "../../models";
+import type { GoogleAIModelFamily } from "../../models";
 
-// https://developers.generativeai.google.com/models/language
-export type GooglePalmModel = "text-bison-001";
+// Note that Google AI is not the same as Vertex AI; both are provided by
+// Google, but Vertex is the GCP product for enterprise, while Google AI is the
+// consumer-ish product. The API is different, and keys are not compatible.
+// https://ai.google.dev/docs/migrate_to_cloud
+export type GoogleAIModel = "gemini-pro";
 
-export type GooglePalmKeyUpdate = Omit<
-  Partial<GooglePalmKey>,
+export type GoogleAIKeyUpdate = Omit<
+  Partial<GoogleAIKey>,
   | "key"
   | "hash"
   | "lastUsed"
@@ -17,13 +21,13 @@ export type GooglePalmKeyUpdate = Omit<
   | "rateLimitedUntil"
 >;
 
-type GooglePalmKeyUsage = {
-  [K in GooglePalmModelFamily as `${K}Tokens`]: number;
+type GoogleAIKeyUsage = {
+  [K in GoogleAIModelFamily as `${K}Tokens`]: number;
 };
 
-export interface GooglePalmKey extends Key, GooglePalmKeyUsage {
-  readonly service: "google-palm";
-  readonly modelFamilies: GooglePalmModelFamily[];
+export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
+  readonly service: "google-ai";
+  readonly modelFamilies: GoogleAIModelFamily[];
   /** The time at which this key was last rate limited. */
   rateLimitedAt: number;
   /** The time until which this key is rate limited. */
@@ -42,27 +46,27 @@ const RATE_LIMIT_LOCKOUT = 2000;
  */
 const KEY_REUSE_DELAY = 500;
 
-export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
-  readonly service = "google-palm";
+export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
+  readonly service = "google-ai";
 
-  private keys: GooglePalmKey[] = [];
+  private keys: GoogleAIKey[] = [];
   private log = logger.child({ module: "key-provider", service: this.service });
 
   constructor() {
-    const keyConfig = config.googlePalmKey?.trim();
+    const keyConfig = config.googleAIKey?.trim();
     if (!keyConfig) {
       this.log.warn(
-        "GOOGLE_PALM_KEY is not set. PaLM API will not be available."
+        "GOOGLE_AI_KEY is not set. Google AI API will not be available."
      );
      return;
    }
    let bareKeys: string[];
    bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
    for (const key of bareKeys) {
-      const newKey: GooglePalmKey = {
+      const newKey: GoogleAIKey = {
        key,
        service: this.service,
-        modelFamilies: ["bison"],
+        modelFamilies: ["gemini-pro"],
        isDisabled: false,
        isRevoked: false,
        promptCount: 0,
@@ -75,11 +79,11 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
-        bisonTokens: 0,
+        "gemini-proTokens": 0,
      };
      this.keys.push(newKey);
    }
-    this.log.info({ keyCount: this.keys.length }, "Loaded PaLM keys.");
+    this.log.info({ keyCount: this.keys.length }, "Loaded Google AI keys.");
  }

  public init() {}
@@ -88,10 +92,10 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

-  public get(_model: GooglePalmModel) {
+  public get(_model: GoogleAIModel) {
    const availableKeys = this.keys.filter((k) => !k.isDisabled);
    if (availableKeys.length === 0) {
-      throw new Error("No Google PaLM keys available");
+      throw new Error("No Google AI keys available");
    }

    // (largely copied from the OpenAI provider, without trial key support)
@@ -122,14 +126,14 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
    return { ...selectedKey };
  }

-  public disable(key: GooglePalmKey) {
+  public disable(key: GoogleAIKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

-  public update(hash: string, update: Partial<GooglePalmKey>) {
+  public update(hash: string, update: Partial<GoogleAIKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }
@@ -142,7 +146,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
    key.promptCount++;
-    key.bisonTokens += tokens;
+    key["gemini-proTokens"] += tokens;
  }

  public getLockoutPeriod() {
+6 -12
View File
@@ -1,29 +1,23 @@
+import type { LLMService, ModelFamily } from "../models";
 import { OpenAIModel } from "./openai/provider";
 import { AnthropicModel } from "./anthropic/provider";
-import { GooglePalmModel } from "./palm/provider";
+import { GoogleAIModel } from "./google-ai/provider";
 import { AwsBedrockModel } from "./aws/provider";
 import { AzureOpenAIModel } from "./azure/provider";
 import { KeyPool } from "./key-pool";
-import type { ModelFamily } from "../models";
 
 /** The request and response format used by a model's API. */
 export type APIFormat =
   | "openai"
   | "anthropic"
-  | "google-palm"
+  | "google-ai"
+  | "mistral-ai"
   | "openai-text"
   | "openai-image";
 
-/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
-export type LLMService =
-  | "openai"
-  | "anthropic"
-  | "google-palm"
-  | "aws"
-  | "azure";
-
 export type Model =
   | OpenAIModel
   | AnthropicModel
-  | GooglePalmModel
+  | GoogleAIModel
   | AwsBedrockModel
   | AzureOpenAIModel;
@@ -77,6 +71,6 @@ export interface KeyProvider<T extends Key = Key> {
 export const keyPool = new KeyPool();
 export { AnthropicKey } from "./anthropic/provider";
 export { OpenAIKey } from "./openai/provider";
-export { GooglePalmKey } from "./palm/provider";
+export { GoogleAIKey } from "././google-ai/provider";
 export { AwsBedrockKey } from "./aws/provider";
 export { AzureOpenAIKey } from "./azure/provider";
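The split between APIFormat and LLMService matters because one request format can be served by multiple services, per the comment in models.ts at the end of this diff (Anthropic/AWS, OpenAI/Azure). A small sketch; the import paths assume a caller at the src root:

import { APIFormat } from "./shared/key-management";
import { LLMService } from "./shared/models";

// One request format, two possible hosting services:
const format: APIFormat = "anthropic";
const hosts: LLMService[] = ["anthropic", "aws"]; // Bedrock also serves Claude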
+12 -32
View File
@@ -4,14 +4,14 @@ import os from "os";
 import schedule from "node-schedule";
 import { config } from "../../config";
 import { logger } from "../../logger";
-import { Key, Model, KeyProvider, LLMService } from "./index";
+import { LLMService, MODEL_FAMILY_SERVICE, ModelFamily } from "../models";
+import { Key, Model, KeyProvider } from "./index";
 import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
 import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
-import { GooglePalmKeyProvider } from "./palm/provider";
+import { GoogleAIKeyProvider } from "./google-ai/provider";
 import { AwsBedrockKeyProvider } from "./aws/provider";
-import { ModelFamily } from "../models";
-import { assertNever } from "../utils";
 import { AzureOpenAIKeyProvider } from "./azure/provider";
+import { MistralAIKeyProvider } from "./mistral-ai/provider";
 
 type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
@@ -24,7 +24,8 @@ export class KeyPool {
   constructor() {
     this.keyProviders.push(new OpenAIKeyProvider());
     this.keyProviders.push(new AnthropicKeyProvider());
-    this.keyProviders.push(new GooglePalmKeyProvider());
+    this.keyProviders.push(new GoogleAIKeyProvider());
+    this.keyProviders.push(new MistralAIKeyProvider());
     this.keyProviders.push(new AwsBedrockKeyProvider());
     this.keyProviders.push(new AzureOpenAIKeyProvider());
   }
@@ -82,7 +83,7 @@
   }
 
   public getLockoutPeriod(family: ModelFamily): number {
-    const service = this.getServiceForModelFamily(family);
+    const service = MODEL_FAMILY_SERVICE[family];
     return this.getKeyProvider(service).getLockoutPeriod(family);
   }
@@ -119,9 +120,12 @@
     } else if (model.startsWith("claude-")) {
       // https://console.anthropic.com/docs/api/reference#parameters
       return "anthropic";
-    } else if (model.includes("bison")) {
+    } else if (model.includes("gemini")) {
       // https://developers.generativeai.google.com/models/language
-      return "google-palm";
+      return "google-ai";
+    } else if (model.includes("mistral")) {
+      // https://docs.mistral.ai/platform/endpoints
+      return "mistral-ai";
     } else if (model.startsWith("anthropic.claude")) {
       // AWS offers models from a few providers
       // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
@@ -132,30 +136,6 @@
     throw new Error(`Unknown service for model '${model}'`);
   }
 
-  private getServiceForModelFamily(modelFamily: ModelFamily): LLMService {
-    switch (modelFamily) {
-      case "gpt4":
-      case "gpt4-32k":
-      case "gpt4-turbo":
-      case "turbo":
-      case "dall-e":
-        return "openai";
-      case "claude":
-        return "anthropic";
-      case "bison":
-        return "google-palm";
-      case "aws-claude":
-        return "aws";
-      case "azure-turbo":
-      case "azure-gpt4":
-      case "azure-gpt4-32k":
-      case "azure-gpt4-turbo":
-        return "azure";
-      default:
-        assertNever(modelFamily);
-    }
-  }
-
   private getKeyProvider(service: LLMService): KeyProvider {
     return this.keyProviders.find((provider) => provider.service === service)!;
   }
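With the switch gone, service resolution is a plain table lookup against MODEL_FAMILY_SERVICE, defined in models.ts at the end of this diff:

import { MODEL_FAMILY_SERVICE } from "../models";

// The map is exhaustive over ModelFamily, so every family resolves:
const service = MODEL_FAMILY_SERVICE["gemini-pro"]; // "google-ai"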
@@ -0,0 +1,112 @@
import axios, { AxiosError } from "axios";
import type { MistralAIModelFamily, OpenAIModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { MistralAIKey, MistralAIKeyProvider } from "./provider";
import { getMistralAIModelFamily, getOpenAIModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
const GET_MODELS_URL = "https://api.mistral.ai/v1/models";
type GetModelsResponse = {
data: [{ id: string }];
};
type MistralAIError = {
message: string;
request_id: string;
};
type UpdateFn = typeof MistralAIKeyProvider.prototype.update;
export class MistralAIKeyChecker extends KeyCheckerBase<MistralAIKey> {
constructor(keys: MistralAIKey[], updateKey: UpdateFn) {
super(keys, {
service: "mistral-ai",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
}
protected async testKeyOrFail(key: MistralAIKey) {
// We only need to check for provisioned models on the initial check.
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
const provisionedModels = await this.getProvisionedModels(key);
const updates = {
modelFamilies: provisionedModels,
};
this.updateKey(key.hash, updates);
}
this.log.info({ key: key.hash, models: key.modelFamilies }, "Checked key.");
}
private async getProvisionedModels(
key: MistralAIKey
): Promise<MistralAIModelFamily[]> {
const opts = { headers: MistralAIKeyChecker.getHeaders(key) };
const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
const models = data.data;
const families = new Set<MistralAIModelFamily>();
models.forEach(({ id }) => families.add(getMistralAIModelFamily(id)));
// We want to update the key's model families here, but we don't want to
// update its `lastChecked` timestamp because we need to let the liveness
// check run before we can consider the key checked.
const familiesArray = [...families];
const keyFromPool = this.keys.find((k) => k.hash === key.hash)!;
this.updateKey(key.hash, {
modelFamilies: familiesArray,
lastChecked: keyFromPool.lastChecked,
});
return familiesArray;
}
protected handleAxiosError(key: MistralAIKey, error: AxiosError) {
if (error.response && MistralAIKeyChecker.errorIsMistralAIError(error)) {
const { status, data } = error.response;
if (status === 401) {
this.log.warn(
{ key: key.hash, error: data },
"Key is invalid or revoked. Disabling key."
);
this.updateKey(key.hash, {
isDisabled: true,
isRevoked: true,
modelFamilies: ["mistral-tiny"],
});
} else {
this.log.error(
{ key: key.hash, status, error: data },
"Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
);
this.updateKey(key.hash, { lastChecked: Date.now() });
}
return;
}
this.log.error(
{ key: key.hash, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}
static errorIsMistralAIError(
error: AxiosError
): error is AxiosError<MistralAIError> {
    const data = error.response?.data as any;
    return Boolean(data?.message && data?.request_id);
}
static getHeaders(key: MistralAIKey) {
return {
Authorization: `Bearer ${key.key}`,
};
}
}
@@ -0,0 +1,210 @@
import crypto from "crypto";
import { Key, KeyProvider, Model } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { MistralAIKeyChecker } from "./checker";
export type MistralAIModel =
| "mistral-tiny"
| "mistral-small"
| "mistral-medium";
export type MistralAIKeyUpdate = Omit<
Partial<MistralAIKey>,
| "key"
| "hash"
| "lastUsed"
| "promptCount"
| "rateLimitedAt"
| "rateLimitedUntil"
>;
type MistralAIKeyUsage = {
[K in MistralAIModelFamily as `${K}Tokens`]: number;
};
export interface MistralAIKey extends Key, MistralAIKeyUsage {
readonly service: "mistral-ai";
readonly modelFamilies: MistralAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 2000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 500;
export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
readonly service = "mistral-ai";
private keys: MistralAIKey[] = [];
private checker?: MistralAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.mistralAIKey?.trim();
if (!keyConfig) {
this.log.warn(
"MISTRAL_AI_KEY is not set. Mistral AI API will not be available."
);
return;
}
let bareKeys: string[];
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: MistralAIKey = {
key,
service: this.service,
modelFamilies: ["mistral-tiny", "mistral-small", "mistral-medium"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
hash: `mst-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
"mistral-tinyTokens": 0,
"mistral-smallTokens": 0,
"mistral-mediumTokens": 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded Mistral AI keys.");
}
public init() {
if (config.checkKeys) {
const updateFn = this.update.bind(this);
this.checker = new MistralAIKeyChecker(this.keys, updateFn);
this.checker.start();
}
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: Model) {
const availableKeys = this.keys.filter((k) => !k.isDisabled);
if (availableKeys.length === 0) {
throw new Error("No Mistral AI keys available");
}
// (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //    a. If all keys were rate limited recently, select the least-recently
    //       rate limited key.
    // 2. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
public disable(key: MistralAIKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<MistralAIKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
const family = getMistralAIModelFamily(model);
key[`${family}Tokens`] += tokens;
}
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return the time until the first key is
// ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
/**
* This is called when we receive a 429, which means there are already five
* concurrent requests running on this key. We don't have any information on
* when these requests will resolve, so all we can do is wait a bit and try
* again. We will lock the key for 2 seconds after getting a 429 before
* retrying in order to give the other requests a chance to finish.
*/
public markRateLimited(keyHash: string) {
this.log.debug({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
const now = Date.now();
key.rateLimitedAt = now;
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
}
public recheck() {}
/**
* Applies a short artificial delay to the key upon dequeueing, in order to
* prevent it from being immediately assigned to another request before the
* current one can be dispatched.
**/
private throttle(hash: string) {
const now = Date.now();
const key = this.keys.find((k) => k.hash === hash)!;
const currentRateLimit = key.rateLimitedUntil;
const nextRateLimit = now + KEY_REUSE_DELAY;
key.rateLimitedAt = now;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
}
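For context, a minimal sketch of how a caller drives the provider above. This is illustrative only: the constructor wiring (loading keys from the environment) lives elsewhere in the key-management module, and the call sites shown are hypothetical, not verbatim.

// Hypothetical call sites; the real pool is owned by the KeyPool aggregate.
const provider = new MistralAIKeyProvider();
provider.init(); // starts MistralAIKeyChecker when config.checkKeys is set

const key = provider.get("mistral-medium"); // throws if every key is disabled
// `get` returns a copy; usage and rate limits are reported back by hash:
provider.incrementUsage(key.hash, "mistral-medium", 512);
// On a 429 from Mistral, the key is locked out for RATE_LIMIT_LOCKOUT ms:
provider.markRateLimited(key.hash);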
+79 -12
@@ -1,8 +1,20 @@
-// Don't import anything here, this is imported by config.ts
+// Don't import any other project files here as this is one of the first modules
+// loaded and it will cause circular imports.
 import pino from "pino";
 import type { Request } from "express";
-import { assertNever } from "./utils";
+
+/**
+ * The service that a model is hosted on. Distinct from `APIFormat` because some
+ * services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
+ */
+export type LLMService =
+  | "openai"
+  | "anthropic"
+  | "google-ai"
+  | "mistral-ai"
+  | "aws"
+  | "azure";
 
 export type OpenAIModelFamily =
   | "turbo"
@@ -11,7 +23,11 @@ export type OpenAIModelFamily =
   | "gpt4-turbo"
   | "dall-e";
 export type AnthropicModelFamily = "claude";
-export type GooglePalmModelFamily = "bison";
+export type GoogleAIModelFamily = "gemini-pro";
+export type MistralAIModelFamily =
+  | "mistral-tiny"
+  | "mistral-small"
+  | "mistral-medium";
 export type AwsBedrockModelFamily = "aws-claude";
 export type AzureOpenAIModelFamily = `azure-${Exclude<
   OpenAIModelFamily,
@@ -20,7 +36,8 @@ export type AzureOpenAIModelFamily = `azure-${Exclude<
 export type ModelFamily =
   | OpenAIModelFamily
   | AnthropicModelFamily
-  | GooglePalmModelFamily
+  | GoogleAIModelFamily
+  | MistralAIModelFamily
   | AwsBedrockModelFamily
   | AzureOpenAIModelFamily;
@@ -33,7 +50,10 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
   "gpt4-turbo",
   "dall-e",
   "claude",
-  "bison",
+  "gemini-pro",
+  "mistral-tiny",
+  "mistral-small",
+  "mistral-medium",
   "aws-claude",
   "azure-turbo",
   "azure-gpt4",
@@ -41,6 +61,17 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
   "azure-gpt4-turbo",
 ] as const);
+
+export const LLM_SERVICES = (<A extends readonly LLMService[]>(
+  arr: A & ([LLMService] extends [A[number]] ? unknown : never)
+) => arr)([
+  "openai",
+  "anthropic",
+  "google-ai",
+  "mistral-ai",
+  "aws",
+  "azure",
+] as const);
 
 export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
   "^gpt-4-1106(-preview)?$": "gpt4-turbo",
   "^gpt-4(-\\d{4})?-vision(-preview)?$": "gpt4-turbo",
@@ -53,7 +84,27 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
   "^dall-e-\\d{1}$": "dall-e",
 };
 
-const modelLogger = pino({ level: "debug" }).child({ module: "startup" });
+export const MODEL_FAMILY_SERVICE: {
+  [f in ModelFamily]: LLMService;
+} = {
+  turbo: "openai",
+  gpt4: "openai",
+  "gpt4-turbo": "openai",
+  "gpt4-32k": "openai",
+  "dall-e": "openai",
+  claude: "anthropic",
+  "aws-claude": "aws",
+  "azure-turbo": "azure",
+  "azure-gpt4": "azure",
+  "azure-gpt4-32k": "azure",
+  "azure-gpt4-turbo": "azure",
+  "gemini-pro": "google-ai",
+  "mistral-tiny": "mistral-ai",
+  "mistral-small": "mistral-ai",
+  "mistral-medium": "mistral-ai",
+};
+
+pino({ level: "debug" }).child({ module: "startup" });
 
 export function getOpenAIModelFamily(
   model: string,
@@ -70,10 +121,19 @@ export function getClaudeModelFamily(model: string): ModelFamily {
   return "claude";
 }
 
-export function getGooglePalmModelFamily(model: string): ModelFamily {
-  if (model.match(/^\w+-bison-\d{3}$/)) return "bison";
-  modelLogger.warn({ model }, "Could not determine Google PaLM model family");
-  return "bison";
+export function getGoogleAIModelFamily(_model: string): ModelFamily {
+  return "gemini-pro";
+}
+
+export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
+  switch (model) {
+    case "mistral-tiny":
+    case "mistral-small":
+    case "mistral-medium":
+      return model;
+    default:
+      return "mistral-tiny";
+  }
 }
 
 export function getAwsBedrockModelFamily(_model: string): ModelFamily {
@@ -130,8 +190,11 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
     case "openai-image":
       modelFamily = getOpenAIModelFamily(model);
       break;
-    case "google-palm":
-      modelFamily = getGooglePalmModelFamily(model);
+    case "google-ai":
+      modelFamily = getGoogleAIModelFamily(model);
+      break;
+    case "mistral-ai":
+      modelFamily = getMistralAIModelFamily(model);
       break;
     default:
       assertNever(req.outboundApi);
@@ -140,3 +203,7 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
   return (req.modelFamily = modelFamily);
 }
+
+function assertNever(x: never): never {
+  throw new Error(`Called assertNever with argument ${x}.`);
+}
+15
@@ -1,3 +1,4 @@
+import { config } from "../config";
 import { ModelFamily } from "./models";
 
 // technically slightly underestimates, because completion tokens cost more
@@ -24,6 +25,15 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
     case "claude":
       cost = 0.00001102;
       break;
+    case "mistral-tiny":
+      cost = 0.00000031;
+      break;
+    case "mistral-small":
+      cost = 0.00000132;
+      break;
+    case "mistral-medium":
+      cost = 0.0000055;
+      break;
   }
   return cost * Math.max(0, tokens);
 }
@@ -40,3 +50,8 @@ export function prettyTokens(tokens: number): string {
     return (tokens / 1000000000).toFixed(3) + "b";
   }
 }
+
+export function getCostSuffix(cost: number) {
+  if (!config.showTokenCosts) return "";
+  return ` ($${cost.toFixed(2)})`;
+}
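The rates are per token, so one million mistral-medium tokens costs 1,000,000 × $0.0000055 = $5.50. A sketch of the two helpers together (import path assumed):

import { getTokenCostUsd, getCostSuffix } from "./stats"; // path assumed

const cost = getTokenCostUsd("mistral-medium", 1_000_000); // 5.5
// getCostSuffix renders " ($5.50)" when config.showTokenCosts is enabled,
// and an empty string otherwise:
console.log(`1.000m tokens${getCostSuffix(cost)}`);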
+55 -24
@@ -1,6 +1,7 @@
-import { Request, Response } from "express";
+import { Response } from "express";
 import { IncomingMessage } from "http";
 import { assertNever } from "./utils";
+import { APIFormat } from "./key-management";
 
 export function initializeSseStream(res: Response) {
   res.statusCode = 200;
@@ -39,54 +40,84 @@
  * that the request is being proxied to. Used to send error messages to the
  * client in the middle of a streaming request.
  */
-export function buildFakeSse(type: string, string: string, req: Request) {
-  let fakeEvent;
-  const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`;
+export function makeCompletionSSE({
+  format,
+  title,
+  message,
+  obj,
+  reqId,
+  model = "unknown",
+}: {
+  format: APIFormat;
+  title: string;
+  message: string;
+  obj?: object;
+  reqId: string | number | object;
+  model?: string;
+}) {
+  const id = String(reqId);
+  const content = `\n\n**${title}**\n${message}${
+    obj ? `\n\`\`\`\n${JSON.stringify(obj, null, 2)}\n\`\`\`\n` : ""
+  }`;
 
-  switch (req.inboundApi) {
+  let event;
+  switch (format) {
     case "openai":
-      fakeEvent = {
-        id: "chatcmpl-" + req.id,
+    case "mistral-ai":
+      event = {
+        id: "chatcmpl-" + id,
         object: "chat.completion.chunk",
         created: Date.now(),
-        model: req.body?.model,
-        choices: [{ delta: { content }, index: 0, finish_reason: type }],
+        model,
+        choices: [{ delta: { content }, index: 0, finish_reason: title }],
       };
       break;
     case "openai-text":
-      fakeEvent = {
-        id: "cmpl-" + req.id,
+      event = {
+        id: "cmpl-" + id,
         object: "text_completion",
         created: Date.now(),
         choices: [
-          { text: content, index: 0, logprobs: null, finish_reason: type },
+          { text: content, index: 0, logprobs: null, finish_reason: title },
         ],
-        model: req.body?.model,
+        model,
       };
       break;
     case "anthropic":
-      fakeEvent = {
+      event = {
         completion: content,
-        stop_reason: type,
-        truncated: false, // I've never seen this be true
+        stop_reason: title,
+        truncated: false,
         stop: null,
-        model: req.body?.model,
-        log_id: "proxy-req-" + req.id,
+        model,
+        log_id: "proxy-req-" + id,
       };
       break;
-    case "google-palm":
+    case "google-ai":
+      return JSON.stringify({
+        candidates: [
+          {
+            content: { parts: [{ text: content }], role: "model" },
+            finishReason: title,
+            index: 0,
+            tokenCount: null,
+            safetyRatings: [],
+          },
+        ],
+      });
     case "openai-image":
-      throw new Error(`SSE not supported for ${req.inboundApi} requests`);
+      throw new Error(`SSE not supported for ${format} requests`);
     default:
-      assertNever(req.inboundApi);
+      assertNever(format);
   }
 
-  if (req.inboundApi === "anthropic") {
+  if (format === "anthropic") {
     return (
-      ["event: completion", `data: ${JSON.stringify(fakeEvent)}`].join("\n") +
+      ["event: completion", `data: ${JSON.stringify(event)}`].join("\n") +
       "\n\n"
     );
   }
 
-  return `data: ${JSON.stringify(fakeEvent)}\n\n`;
+  return `data: ${JSON.stringify(event)}\n\n`;
 }
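A hedged sketch of calling the renamed helper; the real call sites are in the streaming middleware, but the shape is:

// Illustrative values; `res` is the Express response for the SSE stream.
const errorEvent = makeCompletionSSE({
  format: "openai",
  title: "Proxy streaming error",
  message: "The upstream request was terminated.",
  obj: { status: 502 }, // optional details rendered as a fenced JSON block
  reqId: "abc-123",
  model: "gpt-4",
});
res.write(errorEvent); // a single `data: {...}\n\n` chunk for the client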
File diff suppressed because one or more lines are too long
+45
@@ -0,0 +1,45 @@
+import { MistralAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload.js";
+import * as tokenizer from "./mistral-tokenizer-js";
+
+export function init() {
+  tokenizer.initializemistralTokenizer();
+  return true;
+}
+
+export function getTokenCount(prompt: MistralAIChatMessage[] | string) {
+  if (typeof prompt === "string") {
+    return getTextTokenCount(prompt);
+  }
+
+  let chunks = [];
+  for (const message of prompt) {
+    switch (message.role) {
+      case "system":
+        chunks.push(message.content);
+        break;
+      case "assistant":
+        chunks.push(message.content + "</s>");
+        break;
+      case "user":
+        chunks.push("[INST] " + message.content + " [/INST]");
+        break;
+    }
+  }
+  return getTextTokenCount(chunks.join(" "));
+}
+
+function getTextTokenCount(prompt: string) {
+  // Don't try tokenizing if the prompt is massive to prevent DoS.
+  // 500k characters should be sufficient for all supported models.
+  if (prompt.length > 500000) {
+    return {
+      tokenizer: "length fallback",
+      token_count: 100000,
+    };
+  }
+
+  return {
+    tokenizer: "mistral-tokenizer-js",
+    token_count: tokenizer.encode(prompt.normalize("NFKC"))!.length,
+  };
+}
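The chat serialization above mirrors Mistral's instruction format: user turns are wrapped in [INST] ... [/INST] and assistant turns are terminated with </s> before encoding. An illustrative call:

getTokenCount([
  { role: "user", content: "Hello" },
  { role: "assistant", content: "Hi there!" },
]);
// Encodes the joined string "[INST] Hello [/INST] Hi there!</s>" and returns
// { tokenizer: "mistral-tokenizer-js", token_count: <encoded length> }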
+29 -4
@@ -2,7 +2,11 @@ import { Tiktoken } from "tiktoken/lite";
 import cl100k_base from "tiktoken/encoders/cl100k_base.json";
 import { logger } from "../../logger";
 import { libSharp } from "../file-storage";
-import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
+import type {
+  GoogleAIChatMessage,
+  OpenAIChatMessage,
+} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
+import { z } from "zod";
 
 const log = logger.child({ module: "tokenizer", service: "openai" });
 const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
@@ -29,11 +33,11 @@ export async function getTokenCount(
     return getTextTokenCount(prompt);
   }
 
-  const gpt4 = model.startsWith("gpt-4");
+  const oldFormatting = model.startsWith("turbo-0301");
   const vision = model.includes("vision");
 
-  const tokensPerMessage = gpt4 ? 3 : 4;
-  const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present
+  const tokensPerMessage = oldFormatting ? 4 : 3;
+  const tokensPerName = oldFormatting ? -1 : 1; // older formatting replaces role with name if name is present
 
   let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
@@ -228,3 +232,24 @@ export function getOpenAIImageCost(params: {
     token_count: Math.ceil(tokens),
   };
 }
+
+export function estimateGoogleAITokenCount(prompt: string | GoogleAIChatMessage[]) {
+  if (typeof prompt === "string") {
+    return getTextTokenCount(prompt);
+  }
+
+  const tokensPerMessage = 3;
+  let numTokens = 0;
+
+  for (const message of prompt) {
+    numTokens += tokensPerMessage;
+    numTokens += encoder.encode(message.parts[0].text).length;
+  }
+  numTokens += 3;
+
+  return {
+    tokenizer: "tiktoken (google-ai estimate)",
+    token_count: numTokens,
+  };
+}
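The Gemini estimate reuses the cl100k tiktoken encoder as an approximation: three tokens of per-message overhead plus three trailing priming tokens, mirroring the OpenAI chat heuristic. Illustratively:

estimateGoogleAITokenCount([{ role: "user", parts: [{ text: "Hello" }] }]);
// => { tokenizer: "tiktoken (google-ai estimate)",
//      token_count: 3 + encode("Hello").length + 3 }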
+22 -5
@@ -1,5 +1,9 @@
 import { Request } from "express";
-import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
+import type {
+  GoogleAIChatMessage,
+  MistralAIChatMessage,
+  OpenAIChatMessage,
+} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
 import { assertNever } from "../utils";
 import {
   init as initClaude,
@@ -9,12 +13,18 @@ import {
   init as initOpenAi,
   getTokenCount as getOpenAITokenCount,
   getOpenAIImageCost,
+  estimateGoogleAITokenCount,
 } from "./openai";
+import {
+  init as initMistralAI,
+  getTokenCount as getMistralAITokenCount,
+} from "./mistral";
 import { APIFormat } from "../key-management";
 
 export async function init() {
   initClaude();
   initOpenAi();
+  initMistralAI();
 }
 
 /** Tagged union via `service` field of the different types of requests that can
@@ -24,8 +34,10 @@ type TokenCountRequest = { req: Request } & (
   | {
       prompt: string;
       completion?: never;
-      service: "openai-text" | "anthropic" | "google-palm";
+      service: "openai-text" | "anthropic" | "google-ai";
     }
+  | { prompt?: GoogleAIChatMessage[]; completion?: never; service: "google-ai" }
+  | { prompt: MistralAIChatMessage[]; completion?: never; service: "mistral-ai" }
   | { prompt?: never; completion: string; service: APIFormat }
   | { prompt?: never; completion?: never; service: "openai-image" }
 );
@@ -65,11 +77,16 @@ export async function countTokens({
       }),
       tokenization_duration_ms: getElapsedMs(time),
     };
-    case "google-palm":
-      // TODO: Can't find a tokenization library for PaLM. There is an API
+    case "google-ai":
+      // TODO: Can't find a tokenization library for Gemini. There is an API
       // endpoint for it but it adds significant latency to the request.
       return {
-        ...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
+        ...estimateGoogleAITokenCount(prompt ?? (completion || [])),
+        tokenization_duration_ms: getElapsedMs(time),
+      };
+    case "mistral-ai":
+      return {
+        ...getMistralAITokenCount(prompt ?? completion),
         tokenization_duration_ms: getElapsedMs(time),
       };
     default:
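Callers of countTokens now pass the new service tags. A hedged example of the Mistral path (`req` is the Express request being proxied; values illustrative):

const result = await countTokens({
  req,
  prompt: [{ role: "user", content: "Hello" }],
  service: "mistral-ai",
});
// => { tokenizer: "mistral-tokenizer-js", token_count: ...,
//      tokenization_duration_ms: ... }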
+1 -1
@@ -9,7 +9,7 @@ export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object({
   "gpt4-turbo": z.number().optional().default(0),
   "dall-e": z.number().optional().default(0),
   claude: z.number().optional().default(0),
-  bison: z.number().optional().default(0),
+  "gemini-pro": z.number().optional().default(0),
   "aws-claude": z.number().optional().default(0),
 });
+10 -4
@@ -14,7 +14,8 @@ import { config, getFirebaseApp } from "../../config";
 import {
   getAzureOpenAIModelFamily,
   getClaudeModelFamily,
-  getGooglePalmModelFamily,
+  getGoogleAIModelFamily,
+  getMistralAIModelFamily,
   getOpenAIModelFamily,
   MODEL_FAMILIES,
   ModelFamily,
@@ -33,7 +34,10 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
   "gpt4-turbo": 0,
   "dall-e": 0,
   claude: 0,
-  bison: 0,
+  "gemini-pro": 0,
+  "mistral-tiny": 0,
+  "mistral-small": 0,
+  "mistral-medium": 0,
   "aws-claude": 0,
   "azure-turbo": 0,
   "azure-gpt4": 0,
@@ -397,8 +401,10 @@ function getModelFamilyForQuotaUsage(
       return getOpenAIModelFamily(model);
     case "anthropic":
       return getClaudeModelFamily(model);
-    case "google-palm":
-      return getGooglePalmModelFamily(model);
+    case "google-ai":
+      return getGoogleAIModelFamily(model);
+    case "mistral-ai":
+      return getMistralAIModelFamily(model);
     default:
       assertNever(api);
   }
+1 -1
@@ -15,5 +15,5 @@
   },
   "include": ["src"],
   "exclude": ["node_modules"],
-  "files": ["src/types/custom.d.ts"]
+  "files": ["src/shared/custom.d.ts"]
 }