Compare commits
20 Commits
| SHA1 |
|---|
| 49aabddd71 |
| 7b0892ddae |
| 7f92565739 |
| 936d3c0721 |
| 4ffa7fb12b |
| 8dc7464381 |
| d2cd24bfd2 |
| e33f778192 |
| 4a823b216f |
| 01e76cbb1c |
| 655703e680 |
| 3be2687793 |
| 5599a83ae4 |
| de34d41918 |
| c5cd90dcef |
| 8a135a960d |
| 707cbbce16 |
| fad16cc268 |
| 0d3682197c |
| e0624e30fd |
`.env.example` (+3 −3):

```diff
@@ -34,10 +34,10 @@

 # Which model types users are allowed to access.
 # The following model families are recognized:
-# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
+# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | mistral-tiny | mistral-small | mistral-medium | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
 # By default, all models are allowed except for 'dall-e'. To allow DALL-E image
 # generation, uncomment the line below and add 'dall-e' to the list.
-# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
+# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,mistral-tiny,mistral-small,mistral-medium,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo

 # URLs from which requests will be blocked.
 # BLOCKED_ORIGINS=reddit.com,9gag.com
@@ -95,7 +95,7 @@
 # TOKEN_QUOTA_GPT4_TURBO=0
 # TOKEN_QUOTA_DALL_E=0
 # TOKEN_QUOTA_CLAUDE=0
-# TOKEN_QUOTA_BISON=0
+# TOKEN_QUOTA_GEMINI_PRO=0
 # TOKEN_QUOTA_AWS_CLAUDE=0

 # How often to refresh token quotas. (hourly | daily)
```
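`ALLOWED_MODEL_FAMILIES` is a comma-delimited list, and by default everything except `dall-e` is allowed. A standalone sketch of how such a value can be parsed and validated; the helper names here are illustrative, not the project's actual code (which derives the valid set from `MODEL_FAMILIES` in its shared models module):

```ts
// Hypothetical sketch: parse a comma-delimited model family list and reject
// unrecognized entries. Family names taken from the .env.example comment above.
const KNOWN_FAMILIES = [
  "turbo", "gpt4", "gpt4-32k", "gpt4-turbo", "dall-e", "claude",
  "gemini-pro", "mistral-tiny", "mistral-small", "mistral-medium",
  "aws-claude", "azure-turbo", "azure-gpt4", "azure-gpt4-32k", "azure-gpt4-turbo",
] as const;
type KnownFamily = (typeof KNOWN_FAMILIES)[number];

function parseAllowedFamilies(raw: string | undefined): KnownFamily[] {
  // Default mirrors the documented behavior: all families except dall-e.
  if (!raw) return KNOWN_FAMILIES.filter((f) => f !== "dall-e");
  return raw.split(",").map((s) => s.trim()).map((s) => {
    if (!KNOWN_FAMILIES.includes(s as KnownFamily)) {
      throw new Error(`Unrecognized model family: ${s}`);
    }
    return s as KnownFamily;
  });
}

// parseAllowedFamilies(process.env.ALLOWED_MODEL_FAMILIES)
```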
`Dockerfile`:

```diff
@@ -10,4 +10,6 @@ COPY Dockerfile greeting.md* .env* ./
 RUN npm run build
 EXPOSE 7860
 ENV NODE_ENV=production
+# Huggingface free VMs have 16GB of RAM so we can be greedy
+ENV NODE_OPTIONS="--max-old-space-size=12882"
 CMD [ "npm", "start" ]
```
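`--max-old-space-size` is specified in megabytes, so 12882 MB leaves roughly 3.5 GB of the 16 GB VM for the rest of the process and the OS. To confirm the flag took effect, Node's standard `v8` module can report the heap limit (a quick check, not something from this repo):

```ts
import v8 from "node:v8";

// With NODE_OPTIONS="--max-old-space-size=12882" this should print a value
// close to 12882 (the reported limit includes some non-old-space overhead).
console.log(Math.round(v8.getHeapStatistics().heap_size_limit / 1024 / 1024));
```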
DALL-E docs:

````diff
@@ -35,7 +35,7 @@ Add `dall-e` to the `ALLOWED_MODEL_FAMILIES` environment variable to enable DALL
 ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-turbo,dall-e

 # All models as of this writing
-ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,dall-e
+ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,dall-e
 ```

 Refer to [.env.example](../.env.example) for a full list of supported model families. You can add `dall-e` to that list to enable all models.
````
Deployment docs (Dockerfile snippet):

````diff
@@ -32,6 +32,7 @@ COPY Dockerfile greeting.md* .env* ./
 RUN npm run build
 EXPOSE 7860
 ENV NODE_ENV=production
+ENV NODE_OPTIONS="--max-old-space-size=12882"
 CMD [ "npm", "start" ]
 ```
 - Click "Commit new file to `main`" to save the Dockerfile.
````
`package-lock.json` (generated, +34):

```diff
@@ -36,6 +36,7 @@
         "sanitize-html": "^2.11.0",
         "sharp": "^0.32.6",
         "showdown": "^2.1.0",
+        "stream-json": "^1.8.0",
         "tiktoken": "^1.0.10",
         "uuid": "^9.0.0",
         "zlib": "^1.0.5",
@@ -51,6 +52,7 @@
         "@types/node-schedule": "^2.1.0",
         "@types/sanitize-html": "^2.9.0",
         "@types/showdown": "^2.0.0",
+        "@types/stream-json": "^1.7.7",
         "@types/uuid": "^9.0.1",
         "concurrently": "^8.0.1",
         "esbuild": "^0.17.16",
@@ -1185,6 +1187,25 @@
       "integrity": "sha512-70xBJoLv+oXjB5PhtA8vo7erjLDp9/qqI63SRHm4REKrwuPOLs8HhXwlZJBJaB4kC18cCZ1UUZ6Fb/PLFW4TCA==",
       "dev": true
     },
+    "node_modules/@types/stream-chain": {
+      "version": "2.0.4",
+      "resolved": "https://registry.npmjs.org/@types/stream-chain/-/stream-chain-2.0.4.tgz",
+      "integrity": "sha512-V7TsWLHrx79KumkHqSD7F8eR6POpEuWb6PuXJ7s/dRHAf3uVst3Jkp1yZ5XqIfECZLQ4a28vBVstTErmsMBvaQ==",
+      "dev": true,
+      "dependencies": {
+        "@types/node": "*"
+      }
+    },
+    "node_modules/@types/stream-json": {
+      "version": "1.7.7",
+      "resolved": "https://registry.npmjs.org/@types/stream-json/-/stream-json-1.7.7.tgz",
+      "integrity": "sha512-hHG7cLQ09H/m9i0jzL6UJAeLLxIWej90ECn0svO4T8J0nGcl89xZDQ2ujT4WKlvg0GWkcxJbjIDzW/v7BYUM6Q==",
+      "dev": true,
+      "dependencies": {
+        "@types/node": "*",
+        "@types/stream-chain": "*"
+      }
+    },
     "node_modules/@types/uuid": {
       "version": "9.0.1",
       "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.1.tgz",
@@ -5135,6 +5156,11 @@
         "node": ">= 0.8"
       }
     },
+    "node_modules/stream-chain": {
+      "version": "2.2.5",
+      "resolved": "https://registry.npmjs.org/stream-chain/-/stream-chain-2.2.5.tgz",
+      "integrity": "sha512-1TJmBx6aSWqZ4tx7aTpBDXK0/e2hhcNSTV8+CbFJtDjbb+I1mZ8lHit0Grw9GRT+6JbIrrDd8esncgBi8aBXGA=="
+    },
     "node_modules/stream-events": {
       "version": "1.0.5",
       "resolved": "https://registry.npmjs.org/stream-events/-/stream-events-1.0.5.tgz",
@@ -5144,6 +5170,14 @@
         "stubs": "^3.0.0"
       }
     },
+    "node_modules/stream-json": {
+      "version": "1.8.0",
+      "resolved": "https://registry.npmjs.org/stream-json/-/stream-json-1.8.0.tgz",
+      "integrity": "sha512-HZfXngYHUAr1exT4fxlbc1IOce1RYxp2ldeaf97LYCOPSoOqY/1Psp7iGvpb+6JIOgkra9zDYnPX01hGAHzEPw==",
+      "dependencies": {
+        "stream-chain": "^2.2.5"
+      }
+    },
     "node_modules/stream-shift": {
       "version": "1.0.1",
       "resolved": "https://registry.npmjs.org/stream-shift/-/stream-shift-1.0.1.tgz",
```
`package.json`:

```diff
@@ -44,6 +44,7 @@
     "sanitize-html": "^2.11.0",
     "sharp": "^0.32.6",
     "showdown": "^2.1.0",
+    "stream-json": "^1.8.0",
     "tiktoken": "^1.0.10",
     "uuid": "^9.0.0",
     "zlib": "^1.0.5",
@@ -59,6 +60,7 @@
     "@types/node-schedule": "^2.1.0",
     "@types/sanitize-html": "^2.9.0",
     "@types/showdown": "^2.0.0",
+    "@types/stream-json": "^1.7.7",
     "@types/uuid": "^9.0.1",
     "concurrently": "^8.0.1",
     "esbuild": "^0.17.16",
```
HTTP request collection:

```diff
@@ -81,7 +81,7 @@ Authorization: Bearer {{proxy-key}}
 Content-Type: application/json

 {
-  "model": "gpt-3.5-turbo",
+  "model": "gpt-4-1106-preview",
   "max_tokens": 20,
   "stream": true,
   "temperature": 1,
@@ -231,8 +231,36 @@ Content-Type: application/json
 }

 ###
-# @name Proxy / Google PaLM -- OpenAI-to-PaLM API Translation
-POST {{proxy-host}}/proxy/google-palm/v1/chat/completions
+# @name Proxy / Azure OpenAI -- Native Chat Completions
+POST {{proxy-host}}/proxy/azure/openai/chat/completions
 Authorization: Bearer {{proxy-key}}
 Content-Type: application/json

+{
+  "model": "gpt-4",
+  "max_tokens": 20,
+  "stream": true,
+  "temperature": 1,
+  "seed": 2,
+  "messages": [
+    {
+      "role": "user",
+      "content": "Hi what is the name of the fourth president of the united states?"
+    },
+    {
+      "role": "assistant",
+      "content": "That would be George Washington."
+    },
+    {
+      "role": "user",
+      "content": "That's not right."
+    }
+  ]
+}
+
+###
+# @name Proxy / Google AI -- OpenAI-to-Google AI API Translation
+POST {{proxy-host}}/proxy/google-ai/v1/chat/completions
+Authorization: Bearer {{proxy-key}}
+Content-Type: application/json
```
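These are REST-client style request files, with `{{proxy-host}}` and `{{proxy-key}}` as editor-substituted variables. For reference, a rough `fetch` equivalent of the new Google AI translation request; the host and bearer token below are the placeholder values used by the load-test script later in this diff, not anything to ship:

```ts
// Rough fetch equivalent of the Google AI request above. "localhost:7860" and
// "Bearer test" are placeholders; substitute your deployment's host and key.
const res = await fetch(
  "http://localhost:7860/proxy/google-ai/v1/chat/completions",
  {
    method: "POST",
    headers: {
      Authorization: "Bearer test",
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      model: "gemini-pro", // the router force-sets gemini-pro regardless
      max_tokens: 20,
      stream: false,
      messages: [{ role: "user", content: "Hello!" }],
    }),
  }
);
console.log(await res.json());
```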
Load-test script:

```diff
@@ -1,6 +1,6 @@
 const axios = require("axios");

-const concurrentRequests = 5;
+const concurrentRequests = 75;
 const headers = {
   Authorization: "Bearer test",
   "Content-Type": "application/json",
@@ -16,7 +16,7 @@ const payload = {
 const makeRequest = async (i) => {
   try {
     const response = await axios.post(
-      "http://localhost:7860/proxy/azure/openai/v1/chat/completions",
+      "http://localhost:7860/proxy/google-ai/v1/chat/completions",
       payload,
       { headers }
     );
@@ -25,7 +25,8 @@ const makeRequest = async (i) => {
       response.data
     );
   } catch (error) {
-    console.error(`Error in req ${i}:`, error.message);
+    const msg = error.response
+    console.error(`Error in req ${i}:`, error.message, msg || "");
   }
 };
```
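Only the changed lines of the load-test script appear in these hunks; the driver that actually fans out the `concurrentRequests` calls is not shown. It presumably looks something like this (an assumption, not taken from the diff):

```ts
// Hypothetical driver for the makeRequest helper shown above: fire all
// requests at once and wait for every one to settle.
async function main() {
  const requests = Array.from({ length: concurrentRequests }, (_, i) =>
    makeRequest(i)
  );
  await Promise.allSettled(requests);
}
main();
```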
Admin router (+3 −2):

```diff
@@ -4,7 +4,8 @@ import { HttpError } from "../shared/errors";
 import { injectLocals } from "../shared/inject-locals";
 import { withSession } from "../shared/with-session";
 import { injectCsrfToken, checkCsrfToken } from "../shared/inject-csrf";
-import { buildInfoPageHtml } from "../info-page";
+import { renderPage } from "../info-page";
+import { buildInfo } from "../service-info";
 import { loginRouter } from "./login";
 import { usersApiRouter as apiRouter } from "./api/users";
 import { usersWebRouter as webRouter } from "./web/manage";
@@ -26,7 +27,7 @@ adminRouter.use("/", loginRouter);
 adminRouter.use("/manage", authorize({ via: "cookie" }), webRouter);
 adminRouter.use("/service-info", authorize({ via: "cookie" }), (req, res) => {
   return res.send(
-    buildInfoPageHtml(req.protocol + "://" + req.get("host"), true)
+    renderPage(buildInfo(req.protocol + "://" + req.get("host"), true))
   );
 });
```
Admin management routes:

```diff
@@ -200,7 +200,7 @@ router.post("/maintenance", (req, res) => {
       keyPool.recheck("anthropic");
       const size = keyPool
         .list()
-        .filter((k) => k.service !== "google-palm").length;
+        .filter((k) => k.service !== "google-ai").length;
       flash.type = "success";
       flash.message = `Scheduled recheck of ${size} keys for OpenAI and Anthropic.`;
       break;
```
Config module (+41 −15):

```diff
@@ -4,6 +4,7 @@ import path from "path";
 import pino from "pino";
 import type { ModelFamily } from "./shared/models";
+import { MODEL_FAMILIES } from "./shared/models";

 dotenv.config();

 const startupLogger = pino({ level: "debug" }).child({ module: "startup" });
@@ -19,8 +20,16 @@ type Config = {
   openaiKey?: string;
   /** Comma-delimited list of Anthropic API keys. */
   anthropicKey?: string;
-  /** Comma-delimited list of Google PaLM API keys. */
-  googlePalmKey?: string;
+  /**
+   * Comma-delimited list of Google AI API keys. Note that these are not the
+   * same as the GCP keys/credentials used for Vertex AI; the models are the
+   * same but the APIs are different. Vertex is the GCP product for enterprise.
+   **/
+  googleAIKey?: string;
+  /**
+   * Comma-delimited list of Mistral AI API keys.
+   */
+  mistralAIKey?: string;
   /**
    * Comma-delimited list of AWS credentials. Each credential item should be a
    * colon-delimited list of access key, secret key, and AWS region.
@@ -197,7 +206,8 @@ export const config: Config = {
   port: getEnvWithDefault("PORT", 7860),
   openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
   anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
-  googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""),
+  googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
+  mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
   awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
   azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
   proxyKey: getEnvWithDefault("PROXY_KEY", ""),
@@ -229,7 +239,10 @@ export const config: Config = {
     "gpt4-32k",
     "gpt4-turbo",
     "claude",
-    "bison",
+    "gemini-pro",
+    "mistral-tiny",
+    "mistral-small",
+    "mistral-medium",
     "aws-claude",
     "azure-turbo",
     "azure-gpt4",
@@ -361,12 +374,13 @@ export const SENSITIVE_KEYS: (keyof Config)[] = ["googleSheetsSpreadsheetId"];
 * Config keys that are not displayed on the info page at all, generally because
 * they are not relevant to the user or can be inferred from other config.
 */
-export const OMITTED_KEYS: (keyof Config)[] = [
+export const OMITTED_KEYS = [
   "port",
   "logLevel",
   "openaiKey",
   "anthropicKey",
-  "googlePalmKey",
+  "googleAIKey",
+  "mistralAIKey",
   "awsCredentials",
   "azureCredentials",
   "proxyKey",
@@ -387,34 +401,46 @@ export const OMITTED_KEYS: (keyof Config)[] = [
   "staticServiceInfo",
   "checkKeys",
   "allowedModelFamilies",
-];
+] satisfies (keyof Config)[];
+type OmitKeys = (typeof OMITTED_KEYS)[number];
+
+type Printable<T> = {
+  [P in keyof T as Exclude<P, OmitKeys>]: T[P] extends object
+    ? Printable<T[P]>
+    : string;
+};
+type PublicConfig = Printable<Config>;

 const getKeys = Object.keys as <T extends object>(obj: T) => Array<keyof T>;

-export function listConfig(obj: Config = config): Record<string, any> {
-  const result: Record<string, any> = {};
+export function listConfig(obj: Config = config) {
+  const result: Record<string, unknown> = {};
   for (const key of getKeys(obj)) {
     const value = obj[key]?.toString() || "";

-    const shouldOmit =
-      OMITTED_KEYS.includes(key) || value === "" || value === "undefined";
     const shouldMask = SENSITIVE_KEYS.includes(key);
+    const shouldOmit =
+      OMITTED_KEYS.includes(key as OmitKeys) ||
+      value === "" ||
+      value === "undefined";

     if (shouldOmit) {
       continue;
     }

+    const validKey = key as keyof Printable<Config>;
+
     if (value && shouldMask) {
-      result[key] = "********";
+      result[validKey] = "********";
     } else {
-      result[key] = value;
+      result[validKey] = value;
     }

     if (typeof obj[key] === "object" && !Array.isArray(obj[key])) {
       result[key] = listConfig(obj[key] as unknown as Config);
     }
   }
-  return result;
+  return result as PublicConfig;
 }
@@ -433,7 +459,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
     [
       "OPENAI_KEY",
       "ANTHROPIC_KEY",
-      "GOOGLE_PALM_KEY",
+      "GOOGLE_AI_KEY",
       "AWS_CREDENTIALS",
       "AZURE_CREDENTIALS",
     ].includes(String(env))
```
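The `OMITTED_KEYS` change is the standard `satisfies` pattern: the array stays checked against `keyof Config`, but its element type remains the narrow union of literal key names, so `(typeof OMITTED_KEYS)[number]` can drive the `Printable<T>` mapped type. A minimal self-contained illustration (toy `Config`, not the project's):

```ts
// Toy illustration of the `satisfies` + indexed-access pattern used above.
type Config = { port: number; proxyKey: string; promptLogging: boolean };

const OMITTED = ["port", "proxyKey"] satisfies (keyof Config)[];
type OmitKeys = (typeof OMITTED)[number]; // "port" | "proxyKey", not keyof Config

// Mapped type that drops the omitted keys, mirroring Printable<T>:
type Public = { [P in keyof Config as Exclude<P, OmitKeys>]: string };
// Public is { promptLogging: string }
```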
Info page module (+42 −500); most of the aggregation logic moves out to the new service-info module:

```diff
@@ -1,74 +1,39 @@
-/** This whole module really sucks */
+/** This whole module kinda sucks */
 import fs from "fs";
 import { Request, Response } from "express";
 import showdown from "showdown";
-import { config, listConfig } from "./config";
-import {
-  AnthropicKey,
-  AwsBedrockKey,
-  AzureOpenAIKey,
-  GooglePalmKey,
-  keyPool,
-  OpenAIKey,
-} from "./shared/key-management";
-import {
-  AzureOpenAIModelFamily,
-  ModelFamily,
-  OpenAIModelFamily,
-} from "./shared/models";
-import { getUniqueIps } from "./proxy/rate-limit";
-import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
-import { getTokenCostUsd, prettyTokens } from "./shared/stats";
-import { assertNever } from "./shared/utils";
+import { config } from "./config";
+import { buildInfo, ServiceInfo } from "./service-info";
+import { getLastNImages } from "./shared/file-storage/image-history";
+import { keyPool } from "./shared/key-management";
+import { MODEL_FAMILY_SERVICE, ModelFamily } from "./shared/models";

 const INFO_PAGE_TTL = 2000;
+const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
+  "turbo": "GPT-3.5 Turbo",
+  "gpt4": "GPT-4",
+  "gpt4-32k": "GPT-4 32k",
+  "gpt4-turbo": "GPT-4 Turbo",
+  "dall-e": "DALL-E",
+  "claude": "Claude",
+  "gemini-pro": "Gemini Pro",
+  "mistral-tiny": "Mistral 7B",
+  "mistral-small": "Mixtral 8x7B",
+  "mistral-medium": "Mistral Medium (prototype)",
+  "aws-claude": "AWS Claude",
+  "azure-turbo": "Azure GPT-3.5 Turbo",
+  "azure-gpt4": "Azure GPT-4",
+  "azure-gpt4-32k": "Azure GPT-4 32k",
+  "azure-gpt4-turbo": "Azure GPT-4 Turbo",
+};

 const converter = new showdown.Converter();
-const customGreeting = fs.existsSync("greeting.md")
-  ? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
-  : "";
 let infoPageHtml: string | undefined;
 let infoPageLastUpdated = 0;

-type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
-const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
-  k.service === "openai";
-const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
-  k.service === "azure";
-const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
-  k.service === "anthropic";
-const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey =>
-  k.service === "google-palm";
-const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
-
-type ModelAggregates = {
-  active: number;
-  trial?: number;
-  revoked?: number;
-  overQuota?: number;
-  pozzed?: number;
-  awsLogged?: number;
-  queued: number;
-  queueTime: string;
-  tokens: number;
-};
-type ModelAggregateKey = `${ModelFamily}__${keyof ModelAggregates}`;
-type ServiceAggregates = {
-  status?: string;
-  openaiKeys?: number;
-  openaiOrgs?: number;
-  anthropicKeys?: number;
-  palmKeys?: number;
-  awsKeys?: number;
-  azureKeys?: number;
-  proompts: number;
-  tokens: number;
-  tokenCost: number;
-  openAiUncheckedKeys?: number;
-  anthropicUncheckedKeys?: number;
-} & {
-  [modelFamily in ModelFamily]?: ModelAggregates;
-};
-
-const modelStats = new Map<ModelAggregateKey, number>();
-const serviceStats = new Map<keyof ServiceAggregates, number>();
-
 export const handleInfoPage = (req: Request, res: Response) => {
   if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) {
     return res.send(infoPageHtml);
@@ -79,87 +44,16 @@ export const handleInfoPage = (req: Request, res: Response) => {
     ? getExternalUrlForHuggingfaceSpaceId(process.env.SPACE_ID)
     : req.protocol + "://" + req.get("host");

-  infoPageHtml = buildInfoPageHtml(baseUrl + "/proxy");
+  const info = buildInfo(baseUrl + "/proxy");
+  infoPageHtml = renderPage(info);
   infoPageLastUpdated = Date.now();

   res.send(infoPageHtml);
 };

-function getCostString(cost: number) {
-  if (!config.showTokenCosts) return "";
-  return ` ($${cost.toFixed(2)})`;
-}
-
-export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
-  const keys = keyPool.list();
-  const hideFullInfo = config.staticServiceInfo && !asAdmin;
-
-  modelStats.clear();
-  serviceStats.clear();
-  keys.forEach(addKeyToAggregates);
-
-  const openaiKeys = serviceStats.get("openaiKeys") || 0;
-  const anthropicKeys = serviceStats.get("anthropicKeys") || 0;
-  const palmKeys = serviceStats.get("palmKeys") || 0;
-  const awsKeys = serviceStats.get("awsKeys") || 0;
-  const azureKeys = serviceStats.get("azureKeys") || 0;
-  const proompts = serviceStats.get("proompts") || 0;
-  const tokens = serviceStats.get("tokens") || 0;
-  const tokenCost = serviceStats.get("tokenCost") || 0;
-
-  const allowDalle = config.allowedModelFamilies.includes("dall-e");
-
-  const endpoints = {
-    ...(openaiKeys ? { openai: baseUrl + "/openai" } : {}),
-    ...(openaiKeys ? { openai2: baseUrl + "/openai/turbo-instruct" } : {}),
-    ...(openaiKeys && allowDalle
-      ? { ["openai-image"]: baseUrl + "/openai-image" }
-      : {}),
-    ...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}),
-    ...(palmKeys ? { "google-palm": baseUrl + "/google-palm" } : {}),
-    ...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}),
-    ...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}),
-  };
-
-  const stats = {
-    proompts,
-    tookens: `${prettyTokens(tokens)}${getCostString(tokenCost)}`,
-    ...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
-  };
-
-  const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys, azureKeys };
-  for (const key of Object.keys(keyInfo)) {
-    if (!(keyInfo as any)[key]) delete (keyInfo as any)[key];
-  }
-
-  const providerInfo = {
-    ...(openaiKeys ? getOpenAIInfo() : {}),
-    ...(anthropicKeys ? getAnthropicInfo() : {}),
-    ...(palmKeys ? getPalmInfo() : {}),
-    ...(awsKeys ? getAwsInfo() : {}),
-    ...(azureKeys ? getAzureInfo() : {}),
-  };
-
-  if (hideFullInfo) {
-    for (const provider of Object.keys(providerInfo)) {
-      delete (providerInfo as any)[provider].proomptersInQueue;
-      delete (providerInfo as any)[provider].estimatedQueueTime;
-      delete (providerInfo as any)[provider].usage;
-    }
-  }
-
-  const info = {
-    uptime: Math.floor(process.uptime()),
-    endpoints,
-    ...(hideFullInfo ? {} : stats),
-    ...keyInfo,
-    ...providerInfo,
-    config: listConfig(),
-    build: process.env.BUILD_INFO || "dev",
-  };
-
+export function renderPage(info: ServiceInfo) {
   const title = getServerTitle();
-  const headerHtml = buildInfoPageHeader(new showdown.Converter(), title);
+  const headerHtml = buildInfoPageHeader(info);

   return `<!DOCTYPE html>
 <html lang="en">
@@ -178,324 +72,14 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
 </html>`;
 }

-function getUniqueOpenAIOrgs(keys: KeyPoolKey[]) {
-  const orgIds = new Set(
-    keys.filter((k) => k.service === "openai").map((k: any) => k.organizationId)
-  );
-  return orgIds.size;
-}
-
-function increment<T extends keyof ServiceAggregates | ModelAggregateKey>(
-  map: Map<T, number>,
-  key: T,
-  delta = 1
-) {
-  map.set(key, (map.get(key) || 0) + delta);
-}
-
-function addKeyToAggregates(k: KeyPoolKey) {
-  increment(serviceStats, "proompts", k.promptCount);
-  increment(serviceStats, "openaiKeys", k.service === "openai" ? 1 : 0);
-  increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0);
-  increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0);
-  increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0);
-  increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0);
-
-  let sumTokens = 0;
-  let sumCost = 0;
-
-  switch (k.service) {
-    case "openai":
-      if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
-      increment(
-        serviceStats,
-        "openAiUncheckedKeys",
-        Boolean(k.lastChecked) ? 0 : 1
-      );
-
-      k.modelFamilies.forEach((f) => {
-        const tokens = k[`${f}Tokens`];
-        sumTokens += tokens;
-        sumCost += getTokenCostUsd(f, tokens);
-        increment(modelStats, `${f}__tokens`, tokens);
-        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
-        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
-        increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
-        increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
-      });
-      break;
-    case "azure":
-      if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
-      k.modelFamilies.forEach((f) => {
-        const tokens = k[`${f}Tokens`];
-        sumTokens += tokens;
-        sumCost += getTokenCostUsd(f, tokens);
-        increment(modelStats, `${f}__tokens`, tokens);
-        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
-        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
-      });
-      break;
-    case "anthropic": {
-      if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
-      const family = "claude";
-      sumTokens += k.claudeTokens;
-      sumCost += getTokenCostUsd(family, k.claudeTokens);
-      increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
-      increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
-      increment(modelStats, `${family}__tokens`, k.claudeTokens);
-      increment(modelStats, `${family}__pozzed`, k.isPozzed ? 1 : 0);
-      increment(
-        serviceStats,
-        "anthropicUncheckedKeys",
-        Boolean(k.lastChecked) ? 0 : 1
-      );
-      break;
-    }
-    case "google-palm": {
-      if (!keyIsGooglePalmKey(k)) throw new Error("Invalid key type");
-      const family = "bison";
-      sumTokens += k.bisonTokens;
-      sumCost += getTokenCostUsd(family, k.bisonTokens);
-      increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
-      increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
-      increment(modelStats, `${family}__tokens`, k.bisonTokens);
-      break;
-    }
-    case "aws": {
-      if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
-      const family = "aws-claude";
-      sumTokens += k["aws-claudeTokens"];
-      sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]);
-      increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
-      increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
-      increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]);
-
-      // Ignore revoked keys for aws logging stats, but include keys where the
-      // logging status is unknown.
-      const countAsLogged =
-        k.lastChecked && !k.isDisabled && k.awsLoggingStatus !== "disabled";
-      increment(modelStats, `${family}__awsLogged`, countAsLogged ? 1 : 0);
-
-      break;
-    }
-    default:
-      assertNever(k.service);
-  }
-
-  increment(serviceStats, "tokens", sumTokens);
-  increment(serviceStats, "tokenCost", sumCost);
-}
-
-function getOpenAIInfo() {
-  const info: { status?: string; openaiKeys?: number; openaiOrgs?: number } & {
-    [modelFamily in OpenAIModelFamily]?: {
-      usage?: string;
-      activeKeys: number;
-      trialKeys?: number;
-      revokedKeys?: number;
-      overQuotaKeys?: number;
-      proomptersInQueue?: number;
-      estimatedQueueTime?: string;
-    };
-  } = {};
-
-  const keys = keyPool.list().filter(keyIsOpenAIKey);
-  const enabledFamilies = new Set(config.allowedModelFamilies);
-  const accessibleFamilies = keys
-    .flatMap((k) => k.modelFamilies)
-    .filter((f) => enabledFamilies.has(f))
-    .concat("turbo");
-  const familySet = new Set(accessibleFamilies);
-
-  if (config.checkKeys) {
-    const unchecked = serviceStats.get("openAiUncheckedKeys") || 0;
-    if (unchecked > 0) {
-      info.status = `Checking ${unchecked} keys...`;
-    }
-    info.openaiKeys = keys.length;
-    info.openaiOrgs = getUniqueOpenAIOrgs(keys);
-
-    familySet.forEach((f) => {
-      const tokens = modelStats.get(`${f}__tokens`) || 0;
-      const cost = getTokenCostUsd(f, tokens);
-
-      info[f] = {
-        usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-        activeKeys: modelStats.get(`${f}__active`) || 0,
-        trialKeys: modelStats.get(`${f}__trial`) || 0,
-        revokedKeys: modelStats.get(`${f}__revoked`) || 0,
-        overQuotaKeys: modelStats.get(`${f}__overQuota`) || 0,
-      };
-
-      // Don't show trial/revoked keys for non-turbo families.
-      // Generally those stats only make sense for the lowest-tier model.
-      if (f !== "turbo") {
-        delete info[f]!.trialKeys;
-        delete info[f]!.revokedKeys;
-      }
-    });
-  } else {
-    info.status = "Key checking is disabled.";
-    info.turbo = { activeKeys: keys.filter((k) => !k.isDisabled).length };
-    info.gpt4 = {
-      activeKeys: keys.filter(
-        (k) => !k.isDisabled && k.modelFamilies.includes("gpt4")
-      ).length,
-    };
-  }
-
-  familySet.forEach((f) => {
-    if (enabledFamilies.has(f)) {
-      if (!info[f]) info[f] = { activeKeys: 0 }; // may occur if checkKeys is disabled
-      const { estimatedQueueTime, proomptersInQueue } = getQueueInformation(f);
-      info[f]!.proomptersInQueue = proomptersInQueue;
-      info[f]!.estimatedQueueTime = estimatedQueueTime;
-    } else {
-      (info[f]! as any).status = "GPT-3.5-Turbo is disabled on this proxy.";
-    }
-  });
-
-  return info;
-}
-
-function getAnthropicInfo() {
-  const claudeInfo: Partial<ModelAggregates> = {
-    active: modelStats.get("claude__active") || 0,
-    pozzed: modelStats.get("claude__pozzed") || 0,
-    revoked: modelStats.get("claude__revoked") || 0,
-  };
-
-  const queue = getQueueInformation("claude");
-  claudeInfo.queued = queue.proomptersInQueue;
-  claudeInfo.queueTime = queue.estimatedQueueTime;
-
-  const tokens = modelStats.get("claude__tokens") || 0;
-  const cost = getTokenCostUsd("claude", tokens);
-
-  const unchecked =
-    (config.checkKeys && serviceStats.get("anthropicUncheckedKeys")) || 0;
-
-  return {
-    claude: {
-      usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-      ...(unchecked > 0 ? { status: `Checking ${unchecked} keys...` } : {}),
-      activeKeys: claudeInfo.active,
-      revokedKeys: claudeInfo.revoked,
-      ...(config.checkKeys ? { pozzedKeys: claudeInfo.pozzed } : {}),
-      proomptersInQueue: claudeInfo.queued,
-      estimatedQueueTime: claudeInfo.queueTime,
-    },
-  };
-}
-
-function getPalmInfo() {
-  const bisonInfo: Partial<ModelAggregates> = {
-    active: modelStats.get("bison__active") || 0,
-    revoked: modelStats.get("bison__revoked") || 0,
-  };
-
-  const queue = getQueueInformation("bison");
-  bisonInfo.queued = queue.proomptersInQueue;
-  bisonInfo.queueTime = queue.estimatedQueueTime;
-
-  const tokens = modelStats.get("bison__tokens") || 0;
-  const cost = getTokenCostUsd("bison", tokens);
-
-  return {
-    bison: {
-      usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-      activeKeys: bisonInfo.active,
-      revokedKeys: bisonInfo.revoked,
-      proomptersInQueue: bisonInfo.queued,
-      estimatedQueueTime: bisonInfo.queueTime,
-    },
-  };
-}
-
-function getAwsInfo() {
-  const awsInfo: Partial<ModelAggregates> = {
-    active: modelStats.get("aws-claude__active") || 0,
-    revoked: modelStats.get("aws-claude__revoked") || 0,
-  };
-
-  const queue = getQueueInformation("aws-claude");
-  awsInfo.queued = queue.proomptersInQueue;
-  awsInfo.queueTime = queue.estimatedQueueTime;
-
-  const tokens = modelStats.get("aws-claude__tokens") || 0;
-  const cost = getTokenCostUsd("aws-claude", tokens);
-
-  const logged = modelStats.get("aws-claude__awsLogged") || 0;
-  const logMsg = config.allowAwsLogging
-    ? `${logged} active keys are potentially logged.`
-    : `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`;
-
-  return {
-    "aws-claude": {
-      usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
-      activeKeys: awsInfo.active,
-      revokedKeys: awsInfo.revoked,
-      proomptersInQueue: awsInfo.queued,
-      estimatedQueueTime: awsInfo.queueTime,
-      ...(logged > 0 ? { privacy: logMsg } : {}),
-    },
-  };
-}
-
-function getAzureInfo() {
-  const azureFamilies = [
-    "azure-turbo",
-    "azure-gpt4",
-    "azure-gpt4-turbo",
-    "azure-gpt4-32k",
-  ] as const;
-
-  const azureInfo: {
-    [modelFamily in AzureOpenAIModelFamily]?: {
-      usage?: string;
-      activeKeys: number;
-      revokedKeys?: number;
-      proomptersInQueue?: number;
-      estimatedQueueTime?: string;
-    };
-  } = {};
-  for (const family of azureFamilies) {
-    const familyAllowed = config.allowedModelFamilies.includes(family);
-    const activeKeys = modelStats.get(`${family}__active`) || 0;
-
-    if (!familyAllowed || activeKeys === 0) continue;
-
-    azureInfo[family] = {
-      activeKeys,
-      revokedKeys: modelStats.get(`${family}__revoked`) || 0,
-    };
-
-    const queue = getQueueInformation(family);
-    azureInfo[family]!.proomptersInQueue = queue.proomptersInQueue;
-    azureInfo[family]!.estimatedQueueTime = queue.estimatedQueueTime;
-
-    const tokens = modelStats.get(`${family}__tokens`) || 0;
-    const cost = getTokenCostUsd(family, tokens);
-    azureInfo[family]!.usage = `${prettyTokens(tokens)} tokens${getCostString(
-      cost
-    )}`;
-  }
-
-  return azureInfo;
-}
-
+const customGreeting = fs.existsSync("greeting.md")
+  ? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
+  : "";
+
 /**
  * If the server operator provides a `greeting.md` file, it will be included in
  * the rendered info page.
  **/
-function buildInfoPageHeader(converter: showdown.Converter, title: string) {
+function buildInfoPageHeader(info: ServiceInfo) {
+  const title = getServerTitle();
   // TODO: use some templating engine instead of this mess
-  let infoBody = `<!-- Header for Showdown's parser, don't remove this line -->
-# ${title}`;
+  let infoBody = `# ${title}`;
   if (config.promptLogging) {
     infoBody += `\n## Prompt Logging Enabled
 This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps.
@@ -510,45 +94,18 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
   }

   const waits: string[] = [];
   infoBody += `\n## Estimated Wait Times`;

-  if (config.openaiKey) {
-    // TODO: un-fuck this
-    const keys = keyPool.list().filter((k) => k.service === "openai");
+  for (const modelFamily of config.allowedModelFamilies) {
+    const service = MODEL_FAMILY_SERVICE[modelFamily];

-    const turboWait = getQueueInformation("turbo").estimatedQueueTime;
-    waits.push(`**Turbo:** ${turboWait}`);
+    const hasKeys = keyPool.list().some((k) => {
+      return k.service === service && k.modelFamilies.includes(modelFamily);
+    });

-    const gpt4Wait = getQueueInformation("gpt4").estimatedQueueTime;
-    const hasGpt4 = keys.some((k) => k.modelFamilies.includes("gpt4"));
-    const allowedGpt4 = config.allowedModelFamilies.includes("gpt4");
-    if (hasGpt4 && allowedGpt4) {
-      waits.push(`**GPT-4:** ${gpt4Wait}`);
+    const wait = info[modelFamily]?.estimatedQueueTime;
+    if (hasKeys && wait) {
+      waits.push(`**${MODEL_FAMILY_FRIENDLY_NAME[modelFamily] || modelFamily}**: ${wait}`);
     }
-
-    const gpt432kWait = getQueueInformation("gpt4-32k").estimatedQueueTime;
-    const hasGpt432k = keys.some((k) => k.modelFamilies.includes("gpt4-32k"));
-    const allowedGpt432k = config.allowedModelFamilies.includes("gpt4-32k");
-    if (hasGpt432k && allowedGpt432k) {
-      waits.push(`**GPT-4-32k:** ${gpt432kWait}`);
-    }
-
-    const dalleWait = getQueueInformation("dall-e").estimatedQueueTime;
-    const hasDalle = keys.some((k) => k.modelFamilies.includes("dall-e"));
-    const allowedDalle = config.allowedModelFamilies.includes("dall-e");
-    if (hasDalle && allowedDalle) {
-      waits.push(`**DALL-E:** ${dalleWait}`);
-    }
-  }
-
-  if (config.anthropicKey) {
-    const claudeWait = getQueueInformation("claude").estimatedQueueTime;
-    waits.push(`**Claude:** ${claudeWait}`);
-  }
-
-  if (config.awsCredentials) {
-    const awsClaudeWait = getQueueInformation("aws-claude").estimatedQueueTime;
-    waits.push(`**Claude (AWS):** ${awsClaudeWait}`);
   }

   infoBody += "\n\n" + waits.join(" / ");
@@ -565,21 +122,6 @@ function getSelfServiceLinks() {
   return `<footer style="font-size: 0.8em;"><hr /><a target="_blank" href="/user/lookup">Check your user token info</a></footer>`;
 }

-/** Returns queue time in seconds, or minutes + seconds if over 60 seconds. */
-function getQueueInformation(partition: ModelFamily) {
-  const waitMs = getEstimatedWaitTime(partition);
-  const waitTime =
-    waitMs < 60000
-      ? `${Math.round(waitMs / 1000)}sec`
-      : `${Math.round(waitMs / 60000)}min, ${Math.round(
-          (waitMs % 60000) / 1000
-        )}sec`;
-  return {
-    proomptersInQueue: getQueueLength(partition),
-    estimatedQueueTime: waitMs > 2000 ? waitTime : "no wait",
-  };
-}
-
 function getServerTitle() {
   // Use manually set title if available
   if (process.env.SERVER_TITLE) {
```
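For reference, the queue-time helper removed here (its logic presumably now sits behind `buildInfo` in the new service-info module) renders sub-minute waits in seconds, longer waits as minutes plus seconds, and anything at or under two seconds as "no wait". A standalone copy with sample values:

```ts
// Standalone copy of the wait-time formatting shown in the removed
// getQueueInformation, exercised with sample inputs.
function formatWait(waitMs: number): string {
  const waitTime =
    waitMs < 60000
      ? `${Math.round(waitMs / 1000)}sec`
      : `${Math.round(waitMs / 60000)}min, ${Math.round((waitMs % 60000) / 1000)}sec`;
  return waitMs > 2000 ? waitTime : "no wait";
}

console.log(formatWait(1500));  // "no wait"
console.log(formatWait(45000)); // "45sec"
console.log(formatWait(83500)); // "1min, 24sec"
```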
New file, Google AI proxy router:

```diff
@@ -0,0 +1,140 @@
+import { Request, RequestHandler, Router } from "express";
+import { createProxyMiddleware } from "http-proxy-middleware";
+import { v4 } from "uuid";
+import { config } from "../config";
+import { logger } from "../logger";
+import { createQueueMiddleware } from "./queue";
+import { ipLimiter } from "./rate-limit";
+import { handleProxyError } from "./middleware/common";
+import {
+  createOnProxyReqHandler,
+  createPreprocessorMiddleware,
+  finalizeSignedRequest,
+  forceModel,
+} from "./middleware/request";
+import {
+  createOnProxyResHandler,
+  ProxyResHandlerWithBody,
+} from "./middleware/response";
+import { addGoogleAIKey } from "./middleware/request/preprocessors/add-google-ai-key";
+
+let modelsCache: any = null;
+let modelsCacheTime = 0;
+
+const getModelsResponse = () => {
+  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
+    return modelsCache;
+  }
+
+  if (!config.googleAIKey) return { object: "list", data: [] };
+
+  const googleAIVariants = ["gemini-pro"];
+
+  const models = googleAIVariants.map((id) => ({
+    id,
+    object: "model",
+    created: new Date().getTime(),
+    owned_by: "google",
+    permission: [],
+    root: "google",
+    parent: null,
+  }));
+
+  modelsCache = { object: "list", data: models };
+  modelsCacheTime = new Date().getTime();
+
+  return modelsCache;
+};
+
+const handleModelRequest: RequestHandler = (_req, res) => {
+  res.status(200).json(getModelsResponse());
+};
+
+/** Only used for non-streaming requests. */
+const googleAIResponseHandler: ProxyResHandlerWithBody = async (
+  _proxyRes,
+  req,
+  res,
+  body
+) => {
+  if (typeof body !== "object") {
+    throw new Error("Expected body to be an object");
+  }
+
+  if (config.promptLogging) {
+    const host = req.get("host");
+    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
+  }
+
+  if (req.inboundApi === "openai") {
+    req.log.info("Transforming Google AI response to OpenAI format");
+    body = transformGoogleAIResponse(body, req);
+  }
+
+  if (req.tokenizerInfo) {
+    body.proxy_tokenizer = req.tokenizerInfo;
+  }
+
+  res.status(200).json(body);
+};
+
+function transformGoogleAIResponse(
+  resBody: Record<string, any>,
+  req: Request
+): Record<string, any> {
+  const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
+  const parts = resBody.candidates[0].content?.parts ?? [{ text: "" }];
+  const content = parts[0].text.replace(/^(.{0,50}?): /, () => "");
+  return {
+    id: "goo-" + v4(),
+    object: "chat.completion",
+    created: Date.now(),
+    model: req.body.model,
+    usage: {
+      prompt_tokens: req.promptTokens,
+      completion_tokens: req.outputTokens,
+      total_tokens: totalTokens,
+    },
+    choices: [
+      {
+        message: { role: "assistant", content },
+        finish_reason: resBody.candidates[0].finishReason,
+        index: 0,
+      },
+    ],
+  };
+}
+
+const googleAIProxy = createQueueMiddleware({
+  beforeProxy: addGoogleAIKey,
+  proxyMiddleware: createProxyMiddleware({
+    target: "bad-target-will-be-rewritten",
+    router: ({ signedRequest }) => {
+      const { protocol, hostname, path } = signedRequest;
+      return `${protocol}//${hostname}${path}`;
+    },
+    changeOrigin: true,
+    selfHandleResponse: true,
+    logger,
+    on: {
+      proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
+      proxyRes: createOnProxyResHandler([googleAIResponseHandler]),
+      error: handleProxyError,
+    },
+  }),
+});
+
+const googleAIRouter = Router();
+googleAIRouter.get("/v1/models", handleModelRequest);
+// OpenAI-to-Google AI compatibility endpoint.
+googleAIRouter.post(
+  "/v1/chat/completions",
+  ipLimiter,
+  createPreprocessorMiddleware(
+    { inApi: "openai", outApi: "google-ai", service: "google-ai" },
+    { afterTransform: [forceModel("gemini-pro")] }
+  ),
+  googleAIProxy
+);
+
+export const googleAI = googleAIRouter;
```
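`transformGoogleAIResponse` maps Gemini's candidates/parts structure onto the OpenAI chat-completion shape. For orientation, a hand-written, abridged example of the two shapes; the Gemini body follows the public `generateContent` response format:

```ts
// Abridged Gemini generateContent response (hand-written example, not captured
// from a real call):
const geminiBody = {
  candidates: [
    {
      content: { parts: [{ text: "James Madison." }], role: "model" },
      finishReason: "STOP",
      index: 0,
    },
  ],
};

// After transformGoogleAIResponse, clients receive the OpenAI shape instead:
// {
//   id: "goo-<uuid>",
//   object: "chat.completion",
//   model: "gemini-pro",
//   usage: { prompt_tokens, completion_tokens, total_tokens },
//   choices: [{ message: { role: "assistant", content: "James Madison." },
//               finish_reason: "STOP", index: 0 }]
// }
```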
Proxy middleware, common error handling:

```diff
@@ -2,7 +2,7 @@ import { Request, Response } from "express";
 import httpProxy from "http-proxy";
 import { ZodError } from "zod";
 import { generateErrorMessage } from "zod-error";
-import { buildFakeSse } from "../../shared/streaming";
+import { makeCompletionSSE } from "../../shared/streaming";
 import { assertNever } from "../../shared/utils";
 import { QuotaExceededError } from "./request/preprocessors/apply-quota-limits";

@@ -40,11 +40,13 @@ export function writeErrorResponse(
   req: Request,
   res: Response,
   statusCode: number,
+  statusMessage: string,
   errorPayload: Record<string, any>
 ) {
-  const errorSource = errorPayload.error?.type?.startsWith("proxy")
-    ? "proxy"
-    : "upstream";
+  const msg =
+    statusCode === 500
+      ? `The proxy encountered an error while trying to process your prompt.`
+      : `The proxy encountered an error while trying to send your prompt to the upstream service.`;

   // If we're mid-SSE stream, send a data event with the error payload and end
   // the stream. Otherwise just send a normal error response.
@@ -52,10 +54,15 @@
     res.headersSent ||
     String(res.getHeader("content-type")).startsWith("text/event-stream")
   ) {
-    const errorTitle = `${errorSource} error (${statusCode})`;
-    const errorContent = JSON.stringify(errorPayload, null, 2);
-    const msg = buildFakeSse(errorTitle, errorContent, req);
-    res.write(msg);
+    const event = makeCompletionSSE({
+      format: req.inboundApi,
+      title: `Proxy error (HTTP ${statusCode} ${statusMessage})`,
+      message: `${msg} Further technical details are provided below.`,
+      obj: errorPayload,
+      reqId: req.id,
+      model: req.body?.model,
+    });
+    res.write(event);
     res.write(`data: [DONE]\n\n`);
     res.end();
   } else {
@@ -77,8 +84,9 @@ export const classifyErrorAndSend = (
   res: Response
 ) => {
   try {
-    const { status, userMessage, ...errorDetails } = classifyError(err);
-    writeErrorResponse(req, res, status, {
+    const { statusCode, statusMessage, userMessage, ...errorDetails } =
+      classifyError(err);
+    writeErrorResponse(req, res, statusCode, statusMessage, {
       error: { message: userMessage, ...errorDetails },
     });
   } catch (error) {
@@ -88,14 +96,17 @@ export const classifyErrorAndSend = (

 function classifyError(err: Error): {
   /** HTTP status code returned to the client. */
-  status: number;
+  statusCode: number;
+  /** HTTP status message returned to the client. */
+  statusMessage: string;
   /** Message displayed to the user. */
   userMessage: string;
   /** Short error type, e.g. "proxy_validation_error". */
   type: string;
 } & Record<string, any> {
   const defaultError = {
-    status: 500,
+    statusCode: 500,
+    statusMessage: "Internal Server Error",
     userMessage: `Reverse proxy error: ${err.message}`,
     type: "proxy_internal_error",
     stack: err.stack,
@@ -112,19 +123,33 @@
           return `At '${rest.pathComponent}': ${issue.message}`;
         },
       });
-      return { status: 400, userMessage, type: "proxy_validation_error" };
-    case "ForbiddenError":
+      return {
+        statusCode: 400,
+        statusMessage: "Bad Request",
+        userMessage,
+        type: "proxy_validation_error",
+      };
+    case "ZoomerForbiddenError":
       // Mimics a ban notice from OpenAI, thrown when blockZoomerOrigins blocks
       // a request.
       return {
-        status: 403,
+        statusCode: 403,
+        statusMessage: "Forbidden",
         userMessage: `Your account has been disabled for violating our terms of service.`,
         type: "organization_account_disabled",
         code: "policy_violation",
       };
+    case "ForbiddenError":
+      return {
+        statusCode: 403,
+        statusMessage: "Forbidden",
+        userMessage: `Request is not allowed. (${err.message})`,
+        type: "proxy_forbidden",
+      };
     case "QuotaExceededError":
       return {
-        status: 429,
+        statusCode: 429,
+        statusMessage: "Too Many Requests",
         userMessage: `You've exceeded your token quota for this model type.`,
         type: "proxy_quota_exceeded",
         info: (err as QuotaExceededError).quotaInfo,
@@ -134,21 +159,24 @@
       switch (err.code) {
         case "ENOTFOUND":
           return {
-            status: 502,
+            statusCode: 502,
+            statusMessage: "Bad Gateway",
             userMessage: `Reverse proxy encountered a DNS error while trying to connect to the upstream service.`,
             type: "proxy_network_error",
             code: err.code,
           };
         case "ECONNREFUSED":
           return {
-            status: 502,
+            statusCode: 502,
+            statusMessage: "Bad Gateway",
             userMessage: `Reverse proxy couldn't connect to the upstream service.`,
             type: "proxy_network_error",
             code: err.code,
           };
         case "ECONNRESET":
           return {
-            status: 504,
+            statusCode: 504,
+            statusMessage: "Gateway Timeout",
             userMessage: `Reverse proxy timed out while waiting for the upstream service to respond.`,
             type: "proxy_network_error",
             code: err.code,
@@ -165,6 +193,7 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
   const format = req.outboundApi;
   switch (format) {
     case "openai":
+    case "mistral-ai":
       return body.choices[0].message.content;
     case "openai-text":
       return body.choices[0].text;
@@ -177,8 +206,11 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
         return "";
       }
       return body.completion.trim();
-    case "google-palm":
-      return body.candidates[0].output;
+    case "google-ai":
+      if ("choices" in body) {
+        return body.choices[0].message.content;
+      }
+      return body.candidates[0].content.parts[0].text;
     case "openai-image":
       return body.data?.map((item: any) => item.url).join("\n");
     default:
@@ -191,13 +223,14 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
   switch (format) {
     case "openai":
     case "openai-text":
+    case "mistral-ai":
       return body.model;
     case "openai-image":
       return req.body.model;
     case "anthropic":
       // Anthropic confirms the model in the response, but AWS Claude doesn't.
       return body.model || req.body.model;
-    case "google-palm":
+    case "google-ai":
       // Google doesn't confirm the model in the response.
       return req.body.model;
     default:
```
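Once headers have gone out on an SSE stream, the proxy can no longer change the HTTP status, which is why `writeErrorResponse` emits the error as a completion-shaped event followed by the `data: [DONE]` terminator. Roughly what a client sees on the wire (abridged; the exact payload depends on `req.inboundApi` and on how `makeCompletionSSE` formats it):

```
data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"Proxy error (HTTP 502 Bad Gateway) ..."}}]}

data: [DONE]
```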
Proxy request handler (streaming flag):

```diff
@@ -29,7 +29,9 @@ export const createOnProxyReqHandler = ({
     // The streaming flag must be set before any other onProxyReq handler runs,
     // as it may influence the behavior of subsequent handlers.
     // Image generation requests can't be streamed.
-    req.isStreaming = req.body.stream === true || req.body.stream === "true";
+    // TODO: this flag is set in too many places
+    req.isStreaming =
+      req.isStreaming || req.body.stream === true || req.body.stream === "true";
     req.body.stream = req.isStreaming;

     try {
```
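Note the OR here: preprocessors like `addGoogleAIKey` (below) both set `req.isStreaming` and delete `req.body.stream` before this handler runs, so recomputing the flag from the body alone would clobber the earlier decision. OR-ing preserves it, at the cost the TODO acknowledges: the flag is now set in several places.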
add-key middleware:

```diff
@@ -31,10 +31,6 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
     case "anthropic":
       assignedKey = keyPool.get("claude-v1");
       break;
-    case "google-palm":
-      assignedKey = keyPool.get("text-bison-001");
-      delete req.body.stream;
-      break;
     case "openai-text":
       assignedKey = keyPool.get("gpt-3.5-turbo-instruct");
       break;
@@ -42,6 +38,10 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
       throw new Error(
         "OpenAI Chat as an API translation target is not supported"
       );
+    case "google-ai":
+      throw new Error("add-key should not be used for this model.");
+    case "mistral-ai":
+      throw new Error("Mistral AI should never be translated");
     case "openai-image":
       assignedKey = keyPool.get("dall-e-3");
       break;
@@ -71,23 +71,16 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
       if (key.organizationId) {
         proxyReq.setHeader("OpenAI-Organization", key.organizationId);
       }
+    case "mistral-ai":
       proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
       break;
-    case "google-palm":
-      const originalPath = proxyReq.path;
-      proxyReq.path = originalPath.replace(
-        /(\?.*)?$/,
-        `?key=${assignedKey.key}`
-      );
-      break;
     case "azure":
       const azureKey = assignedKey.key;
       proxyReq.setHeader("api-key", azureKey);
       break;
     case "aws":
       throw new Error(
         "add-key should not be used for AWS security credentials. Use sign-aws-request instead."
       );
+    case "google-ai":
+      throw new Error("add-key should not be used for this service.");
     default:
       assertNever(assignedKey.service);
   }
```
block-zoomer-origins middleware:

```diff
@@ -2,10 +2,10 @@ import { HPMRequestCallback } from "../index";

 const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(",");

-class ForbiddenError extends Error {
+class ZoomerForbiddenError extends Error {
   constructor(message: string) {
     super(message);
-    this.name = "ForbiddenError";
+    this.name = "ZoomerForbiddenError";
   }
 }

@@ -22,7 +22,7 @@ export const blockZoomerOrigins: HPMRequestCallback = (_proxyReq, req) => {
     return;
   }

-  throw new ForbiddenError(
+  throw new ZoomerForbiddenError(
     `Your access was terminated due to violation of our policies, please check your email for more information. If you believe this is in error and would like to appeal, please contact us through our help center at help.openai.com.`
   );
 }
```
check-model-family preprocessor:

```diff
@@ -1,13 +1,14 @@
 import { HPMRequestCallback } from "../index";
 import { config } from "../../../../config";
+import { ForbiddenError } from "../../../../shared/errors";
 import { getModelFamilyForRequest } from "../../../../shared/models";

 /**
  * Ensures the selected model family is enabled by the proxy configuration.
  **/
-export const checkModelFamily: HPMRequestCallback = (proxyReq, req) => {
+export const checkModelFamily: HPMRequestCallback = (_proxyReq, req, res) => {
   const family = getModelFamilyForRequest(req);
   if (!config.allowedModelFamilies.includes(family)) {
-    throw new Error(`Model family ${family} is not permitted on this proxy`);
+    throw new ForbiddenError(`Model family '${family}' is not enabled on this proxy`);
   }
 };
```
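Switching from a bare `Error` to the shared `ForbiddenError` matters for the response code: `classifyError` in the middleware diff above maps the `ForbiddenError` name to HTTP 403 with type `proxy_forbidden`, whereas a plain `Error` would fall through to the 500 `proxy_internal_error` default.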
finalize-signed-request middleware:

```diff
@@ -1,9 +1,9 @@
 import type { HPMRequestCallback } from "../index";

 /**
- * For AWS/Azure requests, the body is signed earlier in the request pipeline,
- * before the proxy middleware. This function just assigns the path and headers
- * to the proxy request.
+ * For AWS/Azure/Google requests, the body is signed earlier in the request
+ * pipeline, before the proxy middleware. This function just assigns the path
+ * and headers to the proxy request.
  */
 export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => {
   if (!req.signedRequest) {
```
New file, add-google-ai-key preprocessor:

```diff
@@ -0,0 +1,40 @@
+import { keyPool } from "../../../../shared/key-management";
+import { RequestPreprocessor } from "../index";
+
+export const addGoogleAIKey: RequestPreprocessor = (req) => {
+  const apisValid = req.inboundApi === "openai" && req.outboundApi === "google-ai";
+  const serviceValid = req.service === "google-ai";
+  if (!apisValid || !serviceValid) {
+    throw new Error("addGoogleAIKey called on invalid request");
+  }
+
+  if (!req.body?.model) {
+    throw new Error("You must specify a model with your request.");
+  }
+
+  const model = req.body.model;
+  req.key = keyPool.get(model);
+
+  req.log.info(
+    { key: req.key.hash, model },
+    "Assigned Google AI API key to request"
+  );
+
+  // https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:generateContent?key=$API_KEY
+  // https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY}
+
+  req.isStreaming = req.isStreaming || req.body.stream;
+  delete req.body.stream;
+
+  req.signedRequest = {
+    method: "POST",
+    protocol: "https:",
+    hostname: "generativelanguage.googleapis.com",
+    path: `/v1beta/models/${model}:${req.isStreaming ? "streamGenerateContent" : "generateContent"}?key=${req.key.key}`,
+    headers: {
+      ["host"]: `generativelanguage.googleapis.com`,
+      ["content-type"]: "application/json",
+    },
+    body: JSON.stringify(req.body),
+  };
+};
```
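Unlike OpenAI-style APIs, Gemini selects streaming by endpoint rather than by a `stream` flag in the body, which is why the preprocessor deletes `req.body.stream` and bakes the choice into the path. A quick illustration of the two URLs the template above produces (dummy key):

```ts
// Dummy values for illustration; mirrors the path template in the
// preprocessor above.
const model = "gemini-pro";
const key = "AIza...dummy";
for (const streaming of [false, true]) {
  const verb = streaming ? "streamGenerateContent" : "generateContent";
  console.log(
    `https://generativelanguage.googleapis.com/v1beta/models/${model}:${verb}?key=${key}`
  );
}
```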
count-prompt-tokens preprocessor:

```diff
@@ -1,7 +1,11 @@
 import { RequestPreprocessor } from "../index";
 import { countTokens } from "../../../../shared/tokenization";
 import { assertNever } from "../../../../shared/utils";
-import type { OpenAIChatMessage } from "./transform-outbound-payload";
+import type {
+  GoogleAIChatMessage,
+  MistralAIChatMessage,
+  OpenAIChatMessage,
+} from "./transform-outbound-payload";

 /**
  * Given a request with an already-transformed body, counts the number of
@@ -30,9 +34,15 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
       result = await countTokens({ req, prompt, service });
       break;
     }
-    case "google-palm": {
-      req.outputTokens = req.body.maxOutputTokens;
-      const prompt: string = req.body.prompt.text;
+    case "google-ai": {
+      req.outputTokens = req.body.generationConfig.maxOutputTokens;
+      const prompt: GoogleAIChatMessage[] = req.body.contents;
+      result = await countTokens({ req, prompt, service });
+      break;
+    }
+    case "mistral-ai": {
+      req.outputTokens = req.body.max_tokens;
+      const prompt: MistralAIChatMessage[] = req.body.messages;
       result = await countTokens({ req, prompt, service });
       break;
     }
```
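`req.body.contents` here is the Gemini-format message list produced by the outbound transform (the `GoogleAIChatMessage` type is defined in transform-outbound-payload). Per the public Gemini API it looks roughly like this hand-written example:

```ts
// Hand-written example of a Gemini-format body after transformation; the
// parts/contents structure follows Google's generateContent request format.
const body = {
  contents: [
    { role: "user", parts: [{ text: "Hi, who was the fourth US president?" }] },
    { role: "model", parts: [{ text: "James Madison." }] },
  ],
  generationConfig: { maxOutputTokens: 20 },
};
```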
Prompt-extraction helper (getPromptFromRequest):

```diff
@@ -3,7 +3,10 @@ import { config } from "../../../../config";
 import { assertNever } from "../../../../shared/utils";
 import { RequestPreprocessor } from "../index";
 import { UserInputError } from "../../../../shared/errors";
-import { OpenAIChatMessage } from "./transform-outbound-payload";
+import {
+  MistralAIChatMessage,
+  OpenAIChatMessage,
+} from "./transform-outbound-payload";

 const rejectedClients = new Map<string, number>();

@@ -53,8 +56,9 @@ function getPromptFromRequest(req: Request) {
     case "anthropic":
       return body.prompt;
     case "openai":
+    case "mistral-ai":
       return body.messages
-        .map((msg: OpenAIChatMessage) => {
+        .map((msg: OpenAIChatMessage | MistralAIChatMessage) => {
           const text = Array.isArray(msg.content)
             ? msg.content
                 .map((c) => {
@@ -68,7 +72,7 @@ function getPromptFromRequest(req: Request) {
     case "openai-text":
     case "openai-image":
       return body.prompt;
-    case "google-palm":
+    case "google-ai":
       return body.prompt.text;
     default:
       assertNever(service);
```
||||
@@ -1,13 +1,14 @@
import { Request } from "express";
import { APIFormat, LLMService } from "../../../../shared/key-management";
import { APIFormat } from "../../../../shared/key-management";
import { LLMService } from "../../../../shared/models";
import { RequestPreprocessor } from "../index";

export const setApiFormat = (api: {
  inApi: Request["inboundApi"];
  outApi: APIFormat;
  service: LLMService,
  service: LLMService;
}): RequestPreprocessor => {
  return function configureRequestApiFormat (req) {
  return function configureRequestApiFormat(req) {
    req.inboundApi = api.inApi;
    req.outboundApi = api.outApi;
    req.service = api.service;
@@ -32,7 +32,7 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
    temperature: true,
    top_k: true,
    top_p: true,
  }).parse(req.body);
  }).strip().parse(req.body);

  const credential = getCredentialParts(req);
  const host = AMZ_HOST.replace("%REGION%", credential.region);
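The `.strip()` added here (and to the schemas below) makes zod's unknown-key handling explicit, so fields the schema doesn't declare never reach the upstream API. A minimal sketch of the behavior, using a hypothetical schema:

```
import { z } from "zod";

// .strip() is zod's default object behavior, stated explicitly: unknown keys
// are dropped from the parsed output rather than passed through upstream.
const schema = z.object({ model: z.string() }).strip();
console.log(schema.parse({ model: "claude-v2", extra: "ignored" }));
// -> { model: "claude-v2" }
```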
@@ -1,7 +1,10 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../../config";
import { isTextGenerationRequest, isImageGenerationRequest } from "../../common";
import {
  isTextGenerationRequest,
  isImageGenerationRequest,
} from "../../common";
import { RequestPreprocessor } from "../index";
import { APIFormat } from "../../../../shared/key-management";

@@ -11,23 +14,24 @@ const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
// TODO: move schemas to shared

// https://console.anthropic.com/docs/api/reference#-v1-complete
export const AnthropicV1CompleteSchema = z.object({
  model: z.string(),
  prompt: z.string({
    required_error:
      "No prompt found. Are you sending an OpenAI-formatted request to the Claude endpoint?",
  }),
  max_tokens_to_sample: z.coerce
    .number()
    .int()
    .transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
  stop_sequences: z.array(z.string()).optional(),
  stream: z.boolean().optional().default(false),
  temperature: z.coerce.number().optional().default(1),
  top_k: z.coerce.number().optional(),
  top_p: z.coerce.number().optional(),
  metadata: z.any().optional(),
});
export const AnthropicV1CompleteSchema = z
  .object({
    model: z.string().max(100),
    prompt: z.string({
      required_error:
        "No prompt found. Are you sending an OpenAI-formatted request to the Claude endpoint?",
    }),
    max_tokens_to_sample: z.coerce
      .number()
      .int()
      .transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
    stop_sequences: z.array(z.string().max(500)).optional(),
    stream: z.boolean().optional().default(false),
    temperature: z.coerce.number().optional().default(1),
    top_k: z.coerce.number().optional(),
    top_p: z.coerce.number().optional(),
  })
  .strip();

// https://platform.openai.com/docs/api-reference/chat/create
const OpenAIV1ChatContentArraySchema = z.array(
@@ -43,44 +47,48 @@ const OpenAIV1ChatContentArraySchema = z.array(
  ])
);

export const OpenAIV1ChatCompletionSchema = z.object({
  model: z.string(),
  messages: z.array(
    z.object({
      role: z.enum(["system", "user", "assistant"]),
      content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
      name: z.string().optional(),
    }),
    {
      required_error:
        "No `messages` found. Ensure you've set the correct completion endpoint.",
      invalid_type_error:
        "Messages were not formatted correctly. Refer to the OpenAI Chat API documentation for more information.",
    }
  ),
  temperature: z.number().optional().default(1),
  top_p: z.number().optional().default(1),
  n: z
    .literal(1, {
      errorMap: () => ({
        message: "You may only request a single completion at a time.",
export const OpenAIV1ChatCompletionSchema = z
  .object({
    model: z.string().max(100),
    messages: z.array(
      z.object({
        role: z.enum(["system", "user", "assistant"]),
        content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
        name: z.string().optional(),
      }),
  })
  .optional(),
  stream: z.boolean().optional().default(false),
  stop: z.union([z.string(), z.array(z.string())]).optional(),
  max_tokens: z.coerce
    .number()
    .int()
    .nullish()
    .default(16)
    .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
  frequency_penalty: z.number().optional().default(0),
  presence_penalty: z.number().optional().default(0),
  logit_bias: z.any().optional(),
  user: z.string().optional(),
  seed: z.number().int().optional(),
});
      {
        required_error:
          "No `messages` found. Ensure you've set the correct completion endpoint.",
        invalid_type_error:
          "Messages were not formatted correctly. Refer to the OpenAI Chat API documentation for more information.",
      }
    ),
    temperature: z.number().optional().default(1),
    top_p: z.number().optional().default(1),
    n: z
      .literal(1, {
        errorMap: () => ({
          message: "You may only request a single completion at a time.",
        }),
      })
      .optional(),
    stream: z.boolean().optional().default(false),
    stop: z
      .union([z.string().max(500), z.array(z.string().max(500))])
      .optional(),
    max_tokens: z.coerce
      .number()
      .int()
      .nullish()
      .default(16)
      .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
    frequency_penalty: z.number().optional().default(0),
    presence_penalty: z.number().optional().default(0),
    logit_bias: z.any().optional(),
    user: z.string().max(500).optional(),
    seed: z.number().int().optional(),
  })
  .strip();

export type OpenAIChatMessage = z.infer<
  typeof OpenAIV1ChatCompletionSchema
@@ -90,6 +98,7 @@ const OpenAIV1TextCompletionSchema = z
  .object({
    model: z
      .string()
      .max(100)
      .regex(
        /^gpt-3.5-turbo-instruct/,
        "Model must start with 'gpt-3.5-turbo-instruct'"
@@ -101,50 +110,96 @@ const OpenAIV1TextCompletionSchema = z
    logprobs: z.number().int().nullish().default(null),
    echo: z.boolean().optional().default(false),
    best_of: z.literal(1).optional(),
    stop: z.union([z.string(), z.array(z.string()).max(4)]).optional(),
    suffix: z.string().optional(),
    stop: z
      .union([z.string().max(500), z.array(z.string().max(500)).max(4)])
      .optional(),
    suffix: z.string().max(1000).optional(),
  })
  .strip()
  .merge(OpenAIV1ChatCompletionSchema.omit({ messages: true }));

// https://platform.openai.com/docs/api-reference/images/create
const OpenAIV1ImagesGenerationSchema = z.object({
  prompt: z.string().max(4000),
  model: z.string().optional(),
  quality: z.enum(["standard", "hd"]).optional().default("standard"),
  n: z.number().int().min(1).max(4).optional().default(1),
  response_format: z.enum(["url", "b64_json"]).optional(),
  size: z
    .enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"])
    .optional()
    .default("1024x1024"),
  style: z.enum(["vivid", "natural"]).optional().default("vivid"),
  user: z.string().optional(),
});
const OpenAIV1ImagesGenerationSchema = z
  .object({
    prompt: z.string().max(4000),
    model: z.string().max(100).optional(),
    quality: z.enum(["standard", "hd"]).optional().default("standard"),
    n: z.number().int().min(1).max(4).optional().default(1),
    response_format: z.enum(["url", "b64_json"]).optional(),
    size: z
      .enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"])
      .optional()
      .default("1024x1024"),
    style: z.enum(["vivid", "natural"]).optional().default("vivid"),
    user: z.string().max(500).optional(),
  })
  .strip();

// https://developers.generativeai.google/api/rest/generativelanguage/models/generateText
const PalmV1GenerateTextSchema = z.object({
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
const GoogleAIV1GenerateContentSchema = z
  .object({
    model: z.string().max(100), // actually specified in path but we need it for the router
    stream: z.boolean().optional().default(false), // also used for router
    contents: z.array(
      z.object({
        parts: z.array(z.object({ text: z.string() })),
        role: z.enum(["user", "model"]),
      })
    ),
    tools: z.array(z.object({})).max(0).optional(),
    safetySettings: z.array(z.object({})).max(0).optional(),
    generationConfig: z.object({
      temperature: z.number().optional(),
      maxOutputTokens: z.coerce
        .number()
        .int()
        .optional()
        .default(16)
        .transform((v) => Math.min(v, 1024)), // TODO: Add config
      candidateCount: z.literal(1).optional(),
      topP: z.number().optional(),
      topK: z.number().optional(),
      stopSequences: z.array(z.string().max(500)).max(5).optional(),
    }),
  })
  .strip();

export type GoogleAIChatMessage = z.infer<
  typeof GoogleAIV1GenerateContentSchema
>["contents"][0];

// https://docs.mistral.ai/api#operation/createChatCompletion
const MistralAIV1ChatCompletionsSchema = z.object({
  model: z.string(),
  prompt: z.object({ text: z.string() }),
  temperature: z.number().optional(),
  maxOutputTokens: z.coerce
  messages: z.array(
    z.object({
      role: z.enum(["system", "user", "assistant"]),
      content: z.string(),
    })
  ),
  temperature: z.number().optional().default(0.7),
  top_p: z.number().optional().default(1),
  max_tokens: z.coerce
    .number()
    .int()
    .optional()
    .default(16)
    .transform((v) => Math.min(v, 1024)), // TODO: Add config
  candidateCount: z.literal(1).optional(),
  topP: z.number().optional(),
  topK: z.number().optional(),
  safetySettings: z.array(z.object({})).max(0).optional(),
  stopSequences: z.array(z.string()).max(5).optional(),
    .nullish()
    .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
  stream: z.boolean().optional().default(false),
  safe_mode: z.boolean().optional().default(false),
  random_seed: z.number().int().optional(),
});

export type MistralAIChatMessage = z.infer<
  typeof MistralAIV1ChatCompletionsSchema
>["messages"][0];

const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
  anthropic: AnthropicV1CompleteSchema,
  openai: OpenAIV1ChatCompletionSchema,
  "openai-text": OpenAIV1TextCompletionSchema,
  "openai-image": OpenAIV1ImagesGenerationSchema,
  "google-palm": PalmV1GenerateTextSchema,
  "google-ai": GoogleAIV1GenerateContentSchema,
  "mistral-ai": MistralAIV1ChatCompletionsSchema,
};

/** Transforms an incoming request body to one that matches the target API. */
@@ -174,8 +229,8 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
    return;
  }

  if (req.inboundApi === "openai" && req.outboundApi === "google-palm") {
    req.body = openaiToPalm(req);
  if (req.inboundApi === "openai" && req.outboundApi === "google-ai") {
    req.body = openaiToGoogleAI(req);
    return;
  }

@@ -310,7 +365,9 @@ function openaiToOpenaiImage(req: Request) {
  return OpenAIV1ImagesGenerationSchema.parse(transformed);
}

function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
function openaiToGoogleAI(
  req: Request
): z.infer<typeof GoogleAIV1GenerateContentSchema> {
  const { body } = req;
  const result = OpenAIV1ChatCompletionSchema.safeParse({
    ...body,
@@ -319,40 +376,77 @@
  if (!result.success) {
    req.log.warn(
      { issues: result.error.issues, body },
      "Invalid OpenAI-to-Palm request"
      "Invalid OpenAI-to-Google AI request"
    );
    throw result.error;
  }

  const { messages, ...rest } = result.data;
  const prompt = flattenOpenAIChatMessages(messages);
  const foundNames = new Set<string>();
  const contents = messages
    .map((m) => {
      const role = m.role === "assistant" ? "model" : "user";
      // Detects character names so we can set stop sequences for them as Gemini
      // is prone to continuing as the next character.
      // If names are not available, we'll still try to prefix the message
      // with generic names so we can set stops for them but they don't work
      // as well as real names.
      const text = flattenOpenAIMessageContent(m.content);
      const propName = m.name?.trim();
      const textName =
        m.role === "system" ? "" : text.match(/^(.{0,50}?): /)?.[1]?.trim();
      const name =
        propName || textName || (role === "model" ? "Character" : "User");

      foundNames.add(name);

      // Prefixing messages with their character name seems to help avoid
      // Gemini trying to continue as the next character, or at the very least
      // ensures it will hit the stop sequence. Otherwise it will start a new
      // paragraph and switch perspectives.
      // The response will be very likely to include this prefix so frontends
      // will need to strip it out.
      const textPrefix = textName ? "" : `${name}: `;
      return {
        parts: [{ text: textPrefix + text }],
        role: m.role === "assistant" ? ("model" as const) : ("user" as const),
      };
    })
    .reduce<GoogleAIChatMessage[]>((acc, msg) => {
      const last = acc[acc.length - 1];
      if (last?.role === msg.role) {
        last.parts[0].text += "\n\n" + msg.parts[0].text;
      } else {
        acc.push(msg);
      }
      return acc;
    }, []);

  let stops = rest.stop
    ? Array.isArray(rest.stop)
      ? rest.stop
      : [rest.stop]
    : [];

  stops.push("\n\nUser:");
  stops = [...new Set(stops)];

  z.array(z.string()).max(5).parse(stops);
  stops.push(...Array.from(foundNames).map((name) => `\n${name}:`));
  stops = [...new Set(stops)].slice(0, 5);

  return {
    prompt: { text: prompt },
    maxOutputTokens: rest.max_tokens,
    stopSequences: stops,
    model: "text-bison-001",
    topP: rest.top_p,
    temperature: rest.temperature,
    model: "gemini-pro",
    stream: rest.stream,
    contents,
    tools: [],
    generationConfig: {
      maxOutputTokens: rest.max_tokens,
      stopSequences: stops,
      topP: rest.top_p,
      topK: 40, // openai schema doesn't have this, google ai defaults to 40
      temperature: rest.temperature,
    },
    safetySettings: [
      { category: "HARM_CATEGORY_UNSPECIFIED", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_DEROGATORY", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_TOXICITY", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_VIOLENCE", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_SEXUAL", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_MEDICAL", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_DANGEROUS", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_HATE_SPEECH", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" },
    ],
  };
}
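As a standalone illustration of the reduce step above, a minimal sketch that merges adjacent same-role turns the way `openaiToGoogleAI` does, with hypothetical input:

```
type Turn = { role: "user" | "model"; parts: { text: string }[] };

// Merge adjacent turns that share a role, joining their text with a blank line,
// since the Gemini API expects alternating user/model turns.
function mergeTurns(turns: Turn[]): Turn[] {
  return turns.reduce<Turn[]>((acc, msg) => {
    const last = acc[acc.length - 1];
    if (last?.role === msg.role) {
      last.parts[0].text += "\n\n" + msg.parts[0].text;
    } else {
      acc.push(msg);
    }
    return acc;
  }, []);
}

const merged = mergeTurns([
  { role: "user", parts: [{ text: "System: Be terse." }] },
  { role: "user", parts: [{ text: "User: Hi." }] },
  { role: "model", parts: [{ text: "Character: Hello." }] },
]);
// -> two turns: one "user" turn with both texts joined, then the "model" turn.
console.log(JSON.stringify(merged, null, 2));
```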
@@ -6,7 +6,8 @@ import { RequestPreprocessor } from "../index";

const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
const BISON_MAX_CONTEXT = 8100;
const GOOGLE_AI_MAX_CONTEXT = 32000;
const MISTRAL_AI_MAX_CONTENT = 32768;

/**
 * Assigns `req.promptTokens` and `req.outputTokens` based on the request body
@@ -31,9 +32,11 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
    case "anthropic":
      proxyMax = CLAUDE_MAX_CONTEXT;
      break;
    case "google-palm":
      proxyMax = BISON_MAX_CONTEXT;
    case "google-ai":
      proxyMax = GOOGLE_AI_MAX_CONTEXT;
      break;
    case "mistral-ai":
      proxyMax = MISTRAL_AI_MAX_CONTENT;
    case "openai-image":
      return;
    default:
@@ -62,8 +65,10 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
    modelMax = 100000;
  } else if (model.match(/^claude-2/)) {
    modelMax = 200000;
  } else if (model.match(/^text-bison-\d{3}$/)) {
    modelMax = BISON_MAX_CONTEXT;
  } else if (model.match(/^gemini-\d{3}$/)) {
    modelMax = GOOGLE_AI_MAX_CONTEXT;
  } else if (model.match(/^mistral-(tiny|small|medium)$/)) {
    modelMax = MISTRAL_AI_MAX_CONTENT;
  } else if (model.match(/^anthropic\.claude/)) {
    // Not sure if AWS Claude has the same context limit as Anthropic Claude.
    modelMax = 100000;
@@ -1,8 +1,7 @@
import express from "express";
import { pipeline } from "stream";
import { promisify } from "util";
import {
  buildFakeSse,
  makeCompletionSSE,
  copySseResponseHeaders,
  initializeSseStream,
} from "../../../shared/streaming";
@@ -59,7 +58,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
  const prefersNativeEvents = req.inboundApi === req.outboundApi;
  const contentType = proxyRes.headers["content-type"];

  const adapter = new SSEStreamAdapter({ contentType });
  const adapter = new SSEStreamAdapter({ contentType, api: req.outboundApi });
  const aggregator = new EventAggregator({ format: req.outboundApi });
  const transformer = new SSEMessageTransformer({
    inputFormat: req.outboundApi,
@@ -89,10 +88,20 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
        `Re-enqueueing request due to retryable error during streaming response.`
      );
      req.retryCount++;
      enqueue(req);
      await enqueue(req);
    } else {
      const errorEvent = buildFakeSse("stream-error", err.message, req);
      res.write(`${errorEvent}data: [DONE]\n\n`);
      const { message, stack, lastEvent } = err;
      const eventText = JSON.stringify(lastEvent, null, 2) ?? "undefined";
      const errorEvent = makeCompletionSSE({
        format: req.inboundApi,
        title: "Proxy stream error",
        message: "An unexpected error occurred while streaming the response.",
        obj: { message, stack, lastEvent: eventText },
        reqId: req.id,
        model: req.body?.model,
      });
      res.write(errorEvent);
      res.write(`data: [DONE]\n\n`);
      res.end();
    }
    throw err;
@@ -152,13 +152,13 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
  };
};

function reenqueueRequest(req: Request) {
async function reenqueueRequest(req: Request) {
  req.log.info(
    { key: req.key?.hash, retryCount: req.retryCount },
    `Re-enqueueing request due to retryable error`
  );
  req.retryCount++;
  enqueue(req);
  await enqueue(req);
}

/**
@@ -192,7 +192,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
      } else {
        const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
        req.log.warn({ contentEncoding, key: req.key?.hash }, errorMessage);
        writeErrorResponse(req, res, 500, {
        writeErrorResponse(req, res, 500, "Internal Server Error", {
          error: errorMessage,
          contentEncoding,
        });
@@ -209,7 +209,9 @@
    } catch (error: any) {
      const errorMessage = `Proxy received response with invalid JSON: ${error.message}`;
      req.log.warn({ error: error.stack, key: req.key?.hash }, errorMessage);
      writeErrorResponse(req, res, 500, { error: errorMessage });
      writeErrorResponse(req, res, 500, "Internal Server Error", {
        error: errorMessage,
      });
      return reject(errorMessage);
    }
  });
@@ -237,6 +239,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
  body
) => {
  const statusCode = proxyRes.statusCode || 500;
  const statusMessage = proxyRes.statusMessage || "Internal Server Error";

  if (statusCode < 400) {
    return;
@@ -253,16 +256,16 @@
  } catch (parseError) {
    // Likely Bad Gateway or Gateway Timeout from upstream's reverse proxy
    const hash = req.key?.hash;
    const statusMessage = proxyRes.statusMessage || "Unknown error";
    req.log.warn({ statusCode, statusMessage, key: hash }, parseError.message);

    const errorObject = {
      statusCode,
      statusMessage: proxyRes.statusMessage,
      error: parseError.message,
      proxy_note: `This is likely a temporary error with the upstream service.`,
      status: statusCode,
      statusMessage,
      proxy_note: `Proxy got back an error, but it was not in JSON format. This is likely a temporary problem with the upstream service.`,
    };
    writeErrorResponse(req, res, statusCode, errorObject);

    writeErrorResponse(req, res, statusCode, statusMessage, errorObject);
    throw new HttpError(statusCode, parseError.message);
  }

@@ -288,7 +291,8 @@
    // For Anthropic, this is usually due to missing preamble.
    switch (service) {
      case "openai":
      case "google-palm":
      case "google-ai":
      case "mistral-ai":
      case "azure":
        const filteredCodes = ["content_policy_violation", "content_filter"];
        if (filteredCodes.includes(errorPayload.error?.code)) {
@@ -297,14 +301,14 @@
        } else if (errorPayload.error?.code === "billing_hard_limit_reached") {
          // For some reason, some models return this 400 error instead of the
          // same 429 billing error that other models return.
          handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
          await handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
        } else {
          errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`;
        }
        break;
      case "anthropic":
      case "aws":
        maybeHandleMissingPreambleError(req, errorPayload);
        await maybeHandleMissingPreambleError(req, errorPayload);
        break;
      default:
        assertNever(service);
@@ -314,7 +318,11 @@
    keyPool.disable(req.key!, "revoked");
    errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
  } else if (statusCode === 403) {
    // Amazon is the only service that returns 403.
    if (service === "anthropic") {
      keyPool.disable(req.key!, "revoked");
      errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
      return;
    }
    switch (errorType) {
      case "UnrecognizedClientException":
        // Key is invalid.
@@ -335,19 +343,20 @@
  } else if (statusCode === 429) {
    switch (service) {
      case "openai":
        handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
        await handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
        break;
      case "anthropic":
        handleAnthropicRateLimitError(req, errorPayload);
        await handleAnthropicRateLimitError(req, errorPayload);
        break;
      case "aws":
        handleAwsRateLimitError(req, errorPayload);
        await handleAwsRateLimitError(req, errorPayload);
        break;
      case "azure":
        handleAzureRateLimitError(req, errorPayload);
      case "mistral-ai":
        await handleAzureRateLimitError(req, errorPayload);
        break;
      case "google-palm":
        errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`;
      case "google-ai":
        await handleGoogleAIRateLimitError(req, errorPayload);
        break;
      default:
        assertNever(service);
@@ -369,8 +378,11 @@
      case "anthropic":
        errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
        break;
      case "google-palm":
        errorPayload.proxy_note = `The requested Google PaLM model might not exist, or the key might not be provisioned for it.`;
      case "google-ai":
        errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
        break;
      case "mistral-ai":
        errorPayload.proxy_note = `The requested Mistral AI model might not exist, or the key might not be provisioned for it.`;
        break;
      case "aws":
        errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
@@ -393,7 +405,7 @@
    );
  }

  writeErrorResponse(req, res, statusCode, errorPayload);
  writeErrorResponse(req, res, statusCode, statusMessage, errorPayload);
  throw new HttpError(statusCode, errorPayload.error?.message);
};

@@ -416,7 +428,7 @@
 * }
 * ```
 */
function maybeHandleMissingPreambleError(
async function maybeHandleMissingPreambleError(
  req: Request,
  errorPayload: ProxiedErrorPayload
) {
@@ -429,27 +441,27 @@
      "Request failed due to missing preamble. Key will be marked as such for subsequent requests."
    );
    keyPool.update(req.key!, { requiresPreamble: true });
    reenqueueRequest(req);
    await reenqueueRequest(req);
    throw new RetryableError("Claude request re-enqueued to add preamble.");
  } else {
    errorPayload.proxy_note = `Proxy received unrecognized error from Anthropic. Check the specific error for more information.`;
  }
}

function handleAnthropicRateLimitError(
async function handleAnthropicRateLimitError(
  req: Request,
  errorPayload: ProxiedErrorPayload
) {
  if (errorPayload.error?.type === "rate_limit_error") {
    keyPool.markRateLimited(req.key!);
    reenqueueRequest(req);
    await reenqueueRequest(req);
    throw new RetryableError("Claude rate-limited request re-enqueued.");
  } else {
    errorPayload.proxy_note = `Unrecognized rate limit error from Anthropic. Key may be over quota.`;
  }
}

function handleAwsRateLimitError(
async function handleAwsRateLimitError(
  req: Request,
  errorPayload: ProxiedErrorPayload
) {
@@ -457,7 +469,7 @@
  switch (errorType) {
    case "ThrottlingException":
      keyPool.markRateLimited(req.key!);
      reenqueueRequest(req);
      await reenqueueRequest(req);
      throw new RetryableError("AWS rate-limited request re-enqueued.");
    case "ModelNotReadyException":
      errorPayload.proxy_note = `The requested model is overloaded. Try again in a few seconds.`;
@@ -467,11 +479,11 @@
  }
}

function handleOpenAIRateLimitError(
async function handleOpenAIRateLimitError(
  req: Request,
  tryAgainMessage: string,
  errorPayload: ProxiedErrorPayload
): Record<string, any> {
): Promise<Record<string, any>> {
  const type = errorPayload.error?.type;
  switch (type) {
    case "insufficient_quota":
@@ -500,8 +512,58 @@
      }

      // Per-minute request or token rate limit is exceeded, which we can retry
      reenqueueRequest(req);
      await reenqueueRequest(req);
      throw new RetryableError("Rate-limited request re-enqueued.");
    // WIP/nonfunctional
    // case "tokens_usage_based":
    //   // Weird new rate limit type that seems limited to preview models.
    //   // Distinct from `tokens` type. Can be per-minute or per-day.
    //
    //   // I've seen reports of this error for 500k tokens/day and 10k tokens/min.
    //   // 10k tokens per minute is problematic, because this is much less than
    //   // GPT4-Turbo's max context size for a single prompt and is effectively a
    //   // cap on the max context size for just that key+model, which the app is
    //   // not able to deal with.
    //
    //   // Similarly if there is a 500k tokens per day limit and 450k tokens have
    //   // been used today, the max context for that key becomes 50k tokens until
    //   // the next day and becomes progressively smaller as more tokens are used.
    //
    //   // To work around these keys we will first retry the request a few times.
    //   // After that we will reject the request, and if it's a per-day limit we
    //   // will also disable the key.
    //
    //   // "Rate limit reached for gpt-4-1106-preview in organization org-xxxxxxxxxxxxxxxxxxx on tokens_usage_based per day: Limit 500000, Used 460000, Requested 50000"
    //   // "Rate limit reached for gpt-4-1106-preview in organization org-xxxxxxxxxxxxxxxxxxx on tokens_usage_based per min: Limit 10000, Requested 40000"
    //
    //   const regex =
    //     /Rate limit reached for .+ in organization .+ on \w+ per (day|min): Limit (\d+)(?:, Used (\d+))?, Requested (\d+)/;
    //   const [, period, limit, used, requested] =
    //     errorPayload.error?.message?.match(regex) || [];
    //
    //   req.log.warn(
    //     { key: req.key?.hash, period, limit, used, requested },
    //     "Received `tokens_usage_based` rate limit error from OpenAI."
    //   );
    //
    //   if (!period || !limit || !requested) {
    //     errorPayload.proxy_note = `Unrecognized rate limit error from OpenAI. (${errorPayload.error?.message})`;
    //     break;
    //   }
    //
    //   if (req.retryCount < 2) {
    //     await reenqueueRequest(req);
    //     throw new RetryableError("Rate-limited request re-enqueued.");
    //   }
    //
    //   if (period === "min") {
    //     errorPayload.proxy_note = `Assigned key can't be used for prompts longer than ${limit} tokens, and no other keys are available right now. Reduce the length of your prompt or try again in a few minutes.`;
    //   } else {
    //     errorPayload.proxy_note = `Assigned key has reached its per-day request limit for this model. Try another model.`;
    //   }
    //
    //   keyPool.markRateLimited(req.key!);
    //   break;
    default:
      errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
      break;
@@ -509,7 +571,7 @@
  return errorPayload;
}

function handleAzureRateLimitError(
async function handleAzureRateLimitError(
  req: Request,
  errorPayload: ProxiedErrorPayload
) {
@@ -517,7 +579,7 @@
  switch (code) {
    case "429":
      keyPool.markRateLimited(req.key!);
      reenqueueRequest(req);
      await reenqueueRequest(req);
      throw new RetryableError("Rate-limited request re-enqueued.");
    default:
      errorPayload.proxy_note = `Unrecognized rate limit error from Azure (${code}). Please report this.`;
@@ -525,6 +587,23 @@
  }
}

// {"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}}
async function handleGoogleAIRateLimitError(
  req: Request,
  errorPayload: ProxiedErrorPayload
) {
  const status = errorPayload.error?.status;
  switch (status) {
    case "RESOURCE_EXHAUSTED":
      keyPool.markRateLimited(req.key!);
      await reenqueueRequest(req);
      throw new RetryableError("Rate-limited request re-enqueued.");
    default:
      errorPayload.proxy_note = `Unrecognized rate limit error from Google AI (${status}). Please report this.`;
      break;
  }
}

const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
  if (isTextGenerationRequest(req) || isImageGenerationRequest(req)) {
    const model = req.body.model;
@@ -9,7 +9,10 @@ import {
} from "../common";
import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
import { OpenAIChatMessage } from "../request/preprocessors/transform-outbound-payload";
import {
  MistralAIChatMessage,
  OpenAIChatMessage,
} from "../request/preprocessors/transform-outbound-payload";

/** If prompt logging is enabled, enqueues the prompt for logging. */
export const logPrompt: ProxyResHandlerWithBody = async (
@@ -54,12 +57,13 @@ type OaiImageResult = {
const getPromptForRequest = (
  req: Request,
  responseBody: Record<string, any>
): string | OpenAIChatMessage[] | OaiImageResult => {
): string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult => {
  // Since the prompt logger only runs after the request has been proxied, we
  // can assume the body has already been transformed to the target API's
  // format.
  switch (req.outboundApi) {
    case "openai":
    case "mistral-ai":
      return req.body.messages;
    case "openai-text":
      return req.body.prompt;
@@ -73,7 +77,7 @@
      };
    case "anthropic":
      return req.body.prompt;
    case "google-palm":
    case "google-ai":
      return req.body.prompt.text;
    default:
      assertNever(req.outboundApi);
@@ -81,7 +85,7 @@
};

const flattenMessages = (
  val: string | OpenAIChatMessage[] | OaiImageResult
  val: string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult
): string => {
  if (typeof val === "string") {
    return val.trim();
@@ -4,7 +4,7 @@ import {
  mergeEventsForAnthropic,
  mergeEventsForOpenAIChat,
  mergeEventsForOpenAIText,
  OpenAIChatCompletionStreamEvent
  OpenAIChatCompletionStreamEvent,
} from "./index";

/**
@@ -27,12 +27,13 @@ export class EventAggregator {
  getFinalResponse() {
    switch (this.format) {
      case "openai":
      case "google-ai":
      case "mistral-ai":
        return mergeEventsForOpenAIChat(this.events);
      case "openai-text":
        return mergeEventsForOpenAIText(this.events);
      case "anthropic":
        return mergeEventsForAnthropic(this.events);
      case "google-palm":
      case "openai-image":
        throw new Error(`SSE aggregation not supported for ${this.format}`);
      default:
@@ -25,6 +25,8 @@ export type StreamingCompletionTransformer = (
export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai";
export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";
export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
export { mergeEventsForAnthropic } from "./aggregators/anthropic";
@@ -7,9 +7,10 @@ import {
  anthropicV2ToOpenAI,
  OpenAIChatCompletionStreamEvent,
  openAITextToOpenAIChat,
  googleAIToOpenAI,
  passthroughToOpenAI,
  StreamingCompletionTransformer,
} from "./index";
import { passthroughToOpenAI } from "./transformers/passthrough-to-openai";

const genlog = logger.child({ module: "sse-transformer" });

@@ -92,6 +93,7 @@ export class SSEMessageTransformer extends Transform {
      this.push(transformedMessage);
      callback();
    } catch (err) {
      err.lastEvent = chunk?.toString();
      this.log.error(err, "Error transforming SSE message");
      callback(err);
    }
@@ -104,6 +106,7 @@ function getTransformer(
): StreamingCompletionTransformer {
  switch (responseApi) {
    case "openai":
    case "mistral-ai":
      return passthroughToOpenAI;
    case "openai-text":
      return openAITextToOpenAIChat;
@@ -111,7 +114,8 @@
      return version === "2023-01-01"
        ? anthropicV1ToOpenAI
        : anthropicV2ToOpenAI;
    case "google-palm":
    case "google-ai":
      return googleAIToOpenAI;
    case "openai-image":
      throw new Error(`SSE transformation not supported for ${responseApi}`);
    default:
@@ -1,12 +1,20 @@
import { Transform, TransformOptions } from "stream";

import { StringDecoder } from "string_decoder";
// @ts-ignore
import { Parser } from "lifion-aws-event-stream";
import { logger } from "../../../../logger";
import { RetryableError } from "../index";
import { APIFormat } from "../../../../shared/key-management";
import StreamArray from "stream-json/streamers/StreamArray";
import { makeCompletionSSE } from "../../../../shared/streaming";

const log = logger.child({ module: "sse-stream-adapter" });

type SSEStreamAdapterOptions = TransformOptions & { contentType?: string };
type SSEStreamAdapterOptions = TransformOptions & {
  contentType?: string;
  api: APIFormat;
};
type AwsEventStreamMessage = {
  headers: {
    ":message-type": "event" | "exception";
@@ -21,20 +29,31 @@ type AwsEventStreamMessage = {
 */
export class SSEStreamAdapter extends Transform {
  private readonly isAwsStream;
  private parser = new Parser();
  private readonly isGoogleStream;
  private awsParser = new Parser();
  private jsonParser = StreamArray.withParser();
  private partialMessage = "";
  private decoder = new StringDecoder("utf8");

  constructor(options?: SSEStreamAdapterOptions) {
    super(options);
    this.isAwsStream =
      options?.contentType === "application/vnd.amazon.eventstream";
    this.isGoogleStream = options?.api === "google-ai";

    this.parser.on("data", (data: AwsEventStreamMessage) => {
    this.awsParser.on("data", (data: AwsEventStreamMessage) => {
      const message = this.processAwsEvent(data);
      if (message) {
        this.push(Buffer.from(message + "\n\n"), "utf8");
      }
    });

    this.jsonParser.on("data", (data: { value: any }) => {
      const message = this.processGoogleValue(data.value);
      if (message) {
        this.push(Buffer.from(message + "\n\n"), "utf8");
      }
    });
  }

  protected processAwsEvent(event: AwsEventStreamMessage): string | null {
@@ -53,11 +72,16 @@
        );
        throw new RetryableError("AWS request throttled mid-stream");
      } else {
        log.error(
          { event: eventStr },
          "Received unexpected AWS event stream message"
        );
        return getFakeErrorCompletion("proxy AWS error", eventStr);
        log.error({ event: eventStr }, "Received bad AWS stream event");
        return makeCompletionSSE({
          format: "anthropic",
          title: "Proxy stream error",
          message:
            "The proxy received malformed or unexpected data from AWS while streaming.",
          obj: event,
          reqId: "proxy-sse-adapter-message",
          model: "",
        });
      }
    } else {
      const { bytes } = payload;
@@ -71,44 +95,62 @@
    }
  }

  // Google doesn't use event streams and just sends elements in an array over
  // a long-lived HTTP connection. Needs stream-json to parse the array.
  protected processGoogleValue(value: any): string | null {
    try {
      const candidates = value.candidates ?? [{}];
      const hasParts = candidates[0].content?.parts?.length > 0;
      if (hasParts) {
        return `data: ${JSON.stringify(value)}`;
      } else {
        log.error({ event: value }, "Received bad Google AI event");
        return `data: ${makeCompletionSSE({
          format: "google-ai",
          title: "Proxy stream error",
          message:
            "The proxy received malformed or unexpected data from Google AI while streaming.",
          obj: value,
          reqId: "proxy-sse-adapter-message",
          model: "",
        })}`;
      }
    } catch (error) {
      error.lastEvent = value;
      this.emit("error", error);
      return null;
    }
  }

  _transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
    try {
      if (this.isAwsStream) {
        this.parser.write(chunk);
        this.awsParser.write(chunk);
      } else if (this.isGoogleStream) {
        this.jsonParser.write(chunk);
      } else {
        // We may receive multiple (or partial) SSE messages in a single chunk,
        // so we need to buffer and emit separate stream events for full
        // messages so we can parse/transform them properly.
        const str = chunk.toString("utf8");
        const str = this.decoder.write(chunk);

        const fullMessages = (this.partialMessage + str).split(/\r?\n\r?\n/);
        const fullMessages = (this.partialMessage + str).split(
          /\r\r|\n\n|\r\n\r\n/
        );
        this.partialMessage = fullMessages.pop() || "";

        for (const message of fullMessages) {
          // Mixing line endings will break some clients and our request queue
          // will have already sent \n for heartbeats, so we need to normalize
          // to \n.
          this.push(message.replace(/\r\n/g, "\n") + "\n\n");
          this.push(message.replace(/\r\n?/g, "\n") + "\n\n");
        }
      }
      callback();
    } catch (error) {
      error.lastEvent = chunk?.toString();
      this.emit("error", error);
      callback(error);
    }
  }
}

function getFakeErrorCompletion(type: string, message: string) {
  const content = `\`\`\`\n[${type}: ${message}]\n\`\`\`\n`;
  const fakeEvent = JSON.stringify({
    log_id: "aws-proxy-sse-message",
    stop_reason: type,
    completion:
      "\nProxy encountered an error during streaming response.\n" + content,
    truncated: false,
    stop: null,
    model: "",
  });
  return ["event: completion", `data: ${fakeEvent}\n\n`].join("\n");
}
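A minimal sketch of the stream-json pattern used above, independent of the proxy: `StreamArray.withParser()` emits one `data` event per completed element of a top-level JSON array, even when elements are split across network chunks. The payload below is hypothetical:

```
import StreamArray from "stream-json/streamers/StreamArray";

const parser = StreamArray.withParser();
parser.on("data", ({ value }) => {
  // One event per completed element of the top-level JSON array.
  console.log("element:", value);
});

// Hypothetical payload, deliberately split mid-token across two writes:
parser.write('[{"candidates":[{"content":{"parts":[{"text":"Hel');
parser.write('lo"}]}}]}]');
parser.end();
```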
@@ -0,0 +1,76 @@
import { StreamingCompletionTransformer } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";

const log = logger.child({
  module: "sse-transformer",
  transformer: "google-ai-to-openai",
});

type GoogleAIStreamEvent = {
  candidates: {
    content: { parts: { text: string }[]; role: string };
    finishReason?: "STOP" | "MAX_TOKENS" | "SAFETY" | "RECITATION" | "OTHER";
    index: number;
    tokenCount?: number;
    safetyRatings: { category: string; probability: string }[];
  }[];
};

/**
 * Transforms an incoming Google AI SSE to an equivalent OpenAI
 * chat.completion.chunk SSE.
 */
export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
  const { data, index } = params;

  const rawEvent = parseEvent(data);
  if (!rawEvent.data || rawEvent.data === "[DONE]") {
    return { position: -1 };
  }

  const completionEvent = asCompletion(rawEvent);
  if (!completionEvent) {
    return { position: -1 };
  }

  const parts = completionEvent.candidates[0].content.parts;
  let content = parts[0]?.text ?? "";

  // If this is the first chunk, try stripping speaker names from the response
  // e.g. "John: Hello" -> "Hello"
  if (index === 0) {
    content = content.replace(/^(.*?): /, "").trim();
  }

  const newEvent = {
    id: "goo-" + params.fallbackId,
    object: "chat.completion.chunk" as const,
    created: Date.now(),
    model: params.fallbackModel,
    choices: [
      {
        index: 0,
        delta: { content },
        finish_reason: completionEvent.candidates[0].finishReason ?? null,
      },
    ],
  };

  return { position: -1, event: newEvent };
};

function asCompletion(event: ServerSentEvent): GoogleAIStreamEvent | null {
  try {
    const parsed = JSON.parse(event.data) as GoogleAIStreamEvent;
    if (parsed.candidates?.length > 0) {
      return parsed;
    } else {
      // noinspection ExceptionCaughtLocallyJS
      throw new Error("Missing required fields");
    }
  } catch (error) {
    log.warn({ error: error.stack, event }, "Received invalid event");
  }
  return null;
}
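To illustrate the mapping, a hypothetical Gemini chunk and the approximate OpenAI-style chunk the transformer would emit for it (the id and model are placeholders):

```
// Hypothetical input: one element from the Gemini response array.
const googleChunk = {
  candidates: [
    { content: { parts: [{ text: "Hello there." }], role: "model" }, index: 0 },
  ],
};

// Approximate event emitted by googleAIToOpenAI for that input:
const openAIChunk = {
  id: "goo-req-123", // "goo-" + params.fallbackId
  object: "chat.completion.chunk",
  created: Date.now(),
  model: "gemini-pro", // params.fallbackModel
  choices: [
    { index: 0, delta: { content: "Hello there." }, finish_reason: null },
  ],
};
console.log(googleChunk, openAIChunk);
```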
@@ -0,0 +1,116 @@
import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import {
  getMistralAIModelFamily,
  MistralAIModelFamily,
  ModelFamily,
} from "../shared/models";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
  addKey,
  createOnProxyReqHandler,
  createPreprocessorMiddleware,
  finalizeBody,
} from "./middleware/request";
import {
  createOnProxyResHandler,
  ProxyResHandlerWithBody,
} from "./middleware/response";

// https://docs.mistral.ai/platform/endpoints
export const KNOWN_MISTRAL_AI_MODELS = [
  "mistral-tiny",
  "mistral-small",
  "mistral-medium",
];

let modelsCache: any = null;
let modelsCacheTime = 0;

export function generateModelList(models = KNOWN_MISTRAL_AI_MODELS) {
  let available = new Set<MistralAIModelFamily>();
  for (const key of keyPool.list()) {
    if (key.isDisabled || key.service !== "mistral-ai") continue;
    key.modelFamilies.forEach((family) =>
      available.add(family as MistralAIModelFamily)
    );
  }
  const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
  available = new Set([...available].filter((x) => allowed.has(x)));

  return models
    .map((id) => ({
      id,
      object: "model",
      created: new Date().getTime(),
      owned_by: "mistral-ai",
    }))
    .filter((model) => available.has(getMistralAIModelFamily(model.id)));
}

const handleModelRequest: RequestHandler = (_req, res) => {
  if (new Date().getTime() - modelsCacheTime < 1000 * 60) return modelsCache;
  const result = generateModelList();
  modelsCache = { object: "list", data: result };
  modelsCacheTime = new Date().getTime();
  res.status(200).json(modelsCache);
};

const mistralAIResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  if (config.promptLogging) {
    const host = req.get("host");
    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  if (req.tokenizerInfo) {
    body.proxy_tokenizer = req.tokenizerInfo;
  }

  res.status(200).json(body);
};

const mistralAIProxy = createQueueMiddleware({
  proxyMiddleware: createProxyMiddleware({
    target: "https://api.mistral.ai",
    changeOrigin: true,
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({
        pipeline: [addKey, finalizeBody],
      }),
      proxyRes: createOnProxyResHandler([mistralAIResponseHandler]),
      error: handleProxyError,
    },
  }),
});

const mistralAIRouter = Router();
mistralAIRouter.get("/v1/models", handleModelRequest);
// General chat completion endpoint.
mistralAIRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware({
    inApi: "mistral-ai",
    outApi: "mistral-ai",
    service: "mistral-ai",
  }),
  mistralAIProxy
);

export const mistralAI = mistralAIRouter;
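A usage sketch for the new router. The mount path and token below are assumptions; the actual prefix depends on how the app mounts `mistralAI`, which is not shown in this diff:

```
// Assumes the router is mounted at /proxy/mistral-ai on a local instance.
async function main() {
  const res = await fetch(
    "http://localhost:7860/proxy/mistral-ai/v1/chat/completions",
    {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: "Bearer example-proxy-token", // hypothetical user token
      },
      body: JSON.stringify({
        model: "mistral-small",
        messages: [{ role: "user", content: "Hello!" }],
      }),
    }
  );
  console.log(await res.json());
}

main().catch(console.error);
```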
@@ -17,7 +17,6 @@ import {
} from "./middleware/response";
import { generateModelList } from "./openai";
import {
  mirrorGeneratedImage,
  OpenAIImageGenerationResult,
} from "../shared/file-storage/mirror-generated-image";

@@ -1,170 +0,0 @@
|
||||
import { Request, RequestHandler, Router } from "express";
|
||||
import * as http from "http";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { v4 } from "uuid";
|
||||
import { config } from "../config";
|
||||
import { logger } from "../logger";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
addKey,
|
||||
createOnProxyReqHandler,
|
||||
createPreprocessorMiddleware,
|
||||
finalizeBody,
|
||||
forceModel,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
createOnProxyResHandler,
|
||||
ProxyResHandlerWithBody,
|
||||
} from "./middleware/response";
|
||||
|
||||
let modelsCache: any = null;
|
||||
let modelsCacheTime = 0;
|
||||
|
||||
const getModelsResponse = () => {
|
||||
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
|
||||
return modelsCache;
|
||||
}
|
||||
|
||||
if (!config.googlePalmKey) return { object: "list", data: [] };
|
||||
|
||||
const bisonVariants = ["text-bison-001"];
|
||||
|
||||
const models = bisonVariants.map((id) => ({
|
||||
id,
|
||||
object: "model",
|
||||
created: new Date().getTime(),
|
||||
owned_by: "google",
|
||||
permission: [],
|
||||
root: "palm",
|
||||
parent: null,
|
||||
}));
|
||||
|
||||
modelsCache = { object: "list", data: models };
|
||||
modelsCacheTime = new Date().getTime();
|
||||
|
||||
return modelsCache;
|
||||
};
|
||||
|
||||
const handleModelRequest: RequestHandler = (_req, res) => {
|
||||
res.status(200).json(getModelsResponse());
|
||||
};
|
||||
|
||||
/** Only used for non-streaming requests. */
|
||||
const palmResponseHandler: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
req,
|
||||
res,
|
||||
body
|
||||
) => {
|
||||
if (typeof body !== "object") {
|
||||
throw new Error("Expected body to be an object");
|
||||
}
|
||||
|
||||
if (config.promptLogging) {
|
||||
const host = req.get("host");
|
||||
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
|
||||
}
|
||||
|
||||
if (req.inboundApi === "openai") {
|
||||
req.log.info("Transforming Google PaLM response to OpenAI format");
|
||||
body = transformPalmResponse(body, req);
|
||||
}
|
||||
|
||||
if (req.tokenizerInfo) {
|
||||
body.proxy_tokenizer = req.tokenizerInfo;
|
||||
}
|
||||
|
||||
// TODO: PaLM has no streaming capability which will pose a problem here if
|
||||
// requests wait in the queue for too long. Probably need to fake streaming
|
||||
// and return the entire completion in one stream event using the other
|
||||
// response handler.
|
||||
res.status(200).json(body);
|
||||
};
|
||||
|
||||
/**
|
||||
* Transforms a model response from the Anthropic API to match those from the
|
||||
* OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
|
||||
* is only used for non-streaming requests as streaming requests are handled
|
||||
* on-the-fly.
|
||||
*/
|
||||
function transformPalmResponse(
  palmRespBody: Record<string, any>,
  req: Request
): Record<string, any> {
  const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
  return {
    id: "plm-" + v4(),
    object: "chat.completion",
    created: Date.now(),
    model: req.body.model,
    usage: {
      prompt_tokens: req.promptTokens,
      completion_tokens: req.outputTokens,
      total_tokens: totalTokens,
    },
    choices: [
      {
        message: {
          role: "assistant",
          content: palmRespBody.candidates[0].output,
        },
        finish_reason: null, // palm doesn't return this
        index: 0,
      },
    ],
  };
}

function reassignPathForPalmModel(proxyReq: http.ClientRequest, req: Request) {
  if (req.body.stream) {
    throw new Error("Google PaLM API doesn't support streaming requests");
  }

  // PaLM API specifies the model in the URL path, not the request body. This
  // doesn't work well with our rewriter architecture, so we need to manually
  // fix it here.

  // POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateText
  // POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateMessage

  // The chat api (generateMessage) is not very useful at this time as it has
  // few params and no adjustable safety settings.

  proxyReq.path = proxyReq.path.replace(
    /^\/v1\/chat\/completions/,
    `/v1beta2/models/${req.body.model}:generateText`
  );
}

const googlePalmProxy = createQueueMiddleware({
  proxyMiddleware: createProxyMiddleware({
    target: "https://generativelanguage.googleapis.com",
    changeOrigin: true,
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({
        pipeline: [reassignPathForPalmModel, addKey, finalizeBody],
      }),
      proxyRes: createOnProxyResHandler([palmResponseHandler]),
      error: handleProxyError,
    },
  }),
});

const palmRouter = Router();
palmRouter.get("/v1/models", handleModelRequest);
// OpenAI-to-Google PaLM compatibility endpoint.
palmRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "openai", outApi: "google-palm", service: "google-palm" },
    { afterTransform: [forceModel("text-bison-001")] }
  ),
  googlePalmProxy
);

export const googlePalm = palmRouter;
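For orientation, a rough sketch of what a client call against this compatibility shim looks like — hypothetical values throughout, assuming a local instance on port 7860 whose gatekeeper permits the request:

```ts
// The proxy accepts an OpenAI-style chat completion request, rewrites the
// path to PaLM's generateText endpoint, and converts the response back.
const res = await fetch(
  "http://localhost:7860/proxy/google-palm/v1/chat/completions",
  {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      model: "text-bison-001", // forceModel() overrides whatever is sent
      messages: [{ role: "user", content: "Say hello." }],
      stream: false, // PaLM has no streaming; stream requests are rejected
    }),
  }
);
const completion = await res.json();
console.log(completion.choices[0].message.content);
```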
+56
-31
@@ -14,8 +14,12 @@
import crypto from "crypto";
import type { Handler, Request } from "express";
import { keyPool } from "../shared/key-management";
import { getModelFamilyForRequest, MODEL_FAMILIES, ModelFamily } from "../shared/models";
import { buildFakeSse, initializeSseStream } from "../shared/streaming";
import {
  getModelFamilyForRequest,
  MODEL_FAMILIES,
  ModelFamily,
} from "../shared/models";
import { makeCompletionSSE, initializeSseStream } from "../shared/streaming";
import { logger } from "../logger";
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
@@ -37,6 +41,7 @@ const LOAD_THRESHOLD = parseFloat(process.env.LOAD_THRESHOLD ?? "50");
const PAYLOAD_SCALE_FACTOR = parseFloat(
  process.env.PAYLOAD_SCALE_FACTOR ?? "6"
);
const QUEUE_JOIN_TIMEOUT = 5000;

/**
 * Returns an identifier for a request. This is used to determine if a
@@ -60,7 +65,7 @@ const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>

const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);

export function enqueue(req: Request) {
export async function enqueue(req: Request) {
  const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
  let isGuest = req.user?.token === undefined;

@@ -92,7 +97,7 @@ export function enqueue(req: Request) {
  if (stream === "true" || stream === true || req.isStreaming) {
    const res = req.res!;
    if (!res.headersSent) {
      initStreaming(req);
      await initStreaming(req);
    }
    registerHeartbeat(req);
  } else if (getProxyLoad() > LOAD_THRESHOLD) {
@@ -119,7 +124,9 @@ export function enqueue(req: Request) {
  if (req.retryCount ?? 0 > 0) {
    req.log.info({ retries: req.retryCount }, `Enqueued request for retry.`);
  } else {
    req.log.info(`Enqueued new request.`);
    const size = req.socket.bytesRead;
    const endpoint = req.url?.split("?")[0];
    req.log.info({ size, endpoint }, `Enqueued new request.`);
  }
}

@@ -189,10 +196,10 @@ function processQueue() {
  reqs.filter(Boolean).forEach((req) => {
    if (req?.proceed) {
      const modelFamily = getModelFamilyForRequest(req!);
      req.log.info({
        retries: req.retryCount,
        partition: modelFamily,
      }, `Dequeuing request.`);
      req.log.info(
        { retries: req.retryCount, partition: modelFamily },
        `Dequeuing request.`
      );
      req.proceed();
    }
  });
@@ -327,7 +334,7 @@ export function createQueueMiddleware({
  beforeProxy?: RequestPreprocessor;
  proxyMiddleware: Handler;
}): Handler {
  return (req, res, next) => {
  return async (req, res, next) => {
    req.proceed = async () => {
      if (beforeProxy) {
        try {
@@ -345,7 +352,7 @@ export function createQueueMiddleware({
    };

    try {
      enqueue(req);
      await enqueue(req);
    } catch (err: any) {
      req.res!.status(429).json({
        type: "proxy_error",
@@ -367,8 +374,15 @@ function killQueuedRequest(req: Request) {
  try {
    const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes.`;
    if (res.headersSent) {
      const fakeErrorEvent = buildFakeSse("proxy queue error", message, req);
      res.write(fakeErrorEvent);
      const event = makeCompletionSSE({
        format: req.inboundApi,
        title: "Proxy queue error",
        message,
        reqId: String(req.id),
        model: req.body?.model,
      });
      res.write(event);
      res.write(`data: [DONE]\n\n`);
      res.end();
    } else {
      res.status(500).json({ error: message });
@@ -378,20 +392,39 @@ function killQueuedRequest(req: Request) {
  }
}

function initStreaming(req: Request) {
async function initStreaming(req: Request) {
  const res = req.res!;
  initializeSseStream(res);

  if (req.query.badSseParser) {
    // Some clients have a broken SSE parser that doesn't handle comments
    // correctly. These clients can pass ?badSseParser=true to
    // disable comments in the SSE stream.
    res.write(getHeartbeatPayload());
    return;
  }
  const joinMsg = `: joining queue at position ${
    queue.length
  }\n\n${getHeartbeatPayload()}`;

  res.write(`: joining queue at position ${queue.length}\n\n`);
  res.write(getHeartbeatPayload());
  let drainTimeout: NodeJS.Timeout;
  const welcome = new Promise<void>((resolve, reject) => {
    const onDrain = () => {
      clearTimeout(drainTimeout);
      req.log.debug(`Client finished consuming join message.`);
      res.off("drain", onDrain);
      resolve();
    };

    drainTimeout = setTimeout(() => {
      res.off("drain", onDrain);
      res.destroy();
reject(new Error("Unreponsive streaming client; killing connection"));
|
||||
    }, QUEUE_JOIN_TIMEOUT);

    if (!res.write(joinMsg)) {
      req.log.warn("Kernel buffer is full; holding client request.");
      res.once("drain", onDrain);
    } else {
      clearTimeout(drainTimeout);
      resolve();
    }
  });

  await welcome;
}

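The drain handling above follows Node's standard backpressure contract: `res.write` returns false once the kernel buffer is full, and `drain` fires when the client has consumed it. A stripped-down sketch of the same pattern, separate from the diff itself:

```ts
import type { Writable } from "stream";

// Resolves once `chunk` is flushed to the client, or rejects if the client
// stops reading for longer than `timeoutMs` — mirroring QUEUE_JOIN_TIMEOUT.
function writeOrDie(out: Writable, chunk: string, timeoutMs: number) {
  return new Promise<void>((resolve, reject) => {
    if (out.write(chunk)) return resolve(); // fit in the buffer immediately
    const timer = setTimeout(() => {
      out.off("drain", onDrain);
      reject(new Error("Unresponsive streaming client"));
    }, timeoutMs);
    const onDrain = () => {
      clearTimeout(timer);
      resolve();
    };
    out.once("drain", onDrain);
  });
}
```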
/**
@@ -451,14 +484,6 @@ function removeProxyMiddlewareEventListeners(req: Request) {
export function registerHeartbeat(req: Request) {
  const res = req.res!;

  const currentSize = getHeartbeatSize();
  req.log.debug({
    currentSize,
    HEARTBEAT_INTERVAL,
    PAYLOAD_SCALE_FACTOR,
    MAX_HEARTBEAT_SIZE,
  }, "Joining queue with heartbeat.");

  let isBufferFull = false;
  let bufferFullCount = 0;
  req.heartbeatInterval = setInterval(() => {

+6
-4
@@ -4,7 +4,8 @@ import { checkRisuToken } from "./check-risu-token";
import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import { anthropic } from "./anthropic";
import { googlePalm } from "./palm";
import { googleAI } from "./google-ai";
import { mistralAI } from "./mistral-ai";
import { aws } from "./aws";
import { azure } from "./azure";

@@ -18,8 +19,8 @@ proxyRouter.use((req, _res, next) => {
  next();
});
proxyRouter.use(
  express.json({ limit: "10mb" }),
  express.urlencoded({ extended: true, limit: "10mb" })
  express.json({ limit: "1mb" }),
  express.urlencoded({ extended: true, limit: "1mb" })
);
proxyRouter.use(gatekeeper);
proxyRouter.use(checkRisuToken);
@@ -31,7 +32,8 @@ proxyRouter.use((req, _res, next) => {
proxyRouter.use("/openai", addV1, openai);
proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-palm", addV1, googlePalm);
proxyRouter.use("/google-ai", addV1, googleAI);
proxyRouter.use("/mistral-ai", addV1, mistralAI);
proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/azure/openai", addV1, azure);
// Redirect browser requests to the homepage.

+9
-3
@@ -12,7 +12,8 @@ import { setupAssetsDir } from "./shared/file-storage/setup-assets-dir";
import { keyPool } from "./shared/key-management";
import { adminRouter } from "./admin/routes";
import { proxyRouter } from "./proxy/routes";
import { handleInfoPage } from "./info-page";
import { handleInfoPage, renderPage } from "./info-page";
import { buildInfo } from "./service-info";
import { logQueue } from "./shared/prompt-logging";
import { start as startRequestQueue } from "./proxy/queue";
import { init as initUserStore } from "./shared/users/user-store";
@@ -53,6 +54,10 @@ app.use(
// a load balancer/reverse proxy, which is necessary to determine request IP
// addresses correctly.
app.set("trust proxy", true);
app.use((req, _res, next) => {
  req.log.info({ ip: req.ip, forwardedFor: req.get("x-forwarded-for") });
  next();
});

app.set("view engine", "ejs");
app.set("views", [
@@ -67,13 +72,14 @@ app.get("/health", (_req, res) => res.sendStatus(200));
app.use(cors());
app.use(checkOrigin);

// routes
app.get("/", handleInfoPage);
app.get("/status", (req, res) => {
  res.json(buildInfo(req.protocol + "://" + req.get("host"), false));
});
app.use("/admin", adminRouter);
app.use("/proxy", proxyRouter);
app.use("/user", userRouter);

// 500 and 404
app.use((err: any, _req: unknown, res: express.Response, _next: unknown) => {
  if (err.status) {
    res.status(err.status).json({ error: err.message });

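A quick smoke test for the new route — assuming a local instance; `/status` serves the same JSON that backs the info page, with admin-only fields omitted:

```ts
const info = await fetch("http://localhost:7860/status").then((r) => r.json());
console.log(info.uptime, Object.keys(info.endpoints));
```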
@@ -0,0 +1,441 @@
/** Calculates and returns stats about the service. */
import { config, listConfig } from "./config";
import {
  AnthropicKey,
  AwsBedrockKey,
  AzureOpenAIKey,
  GoogleAIKey,
  keyPool,
  OpenAIKey,
} from "./shared/key-management";
import {
  AnthropicModelFamily,
  assertIsKnownModelFamily,
  AwsBedrockModelFamily,
  AzureOpenAIModelFamily,
  GoogleAIModelFamily,
  LLM_SERVICES,
  LLMService,
  MistralAIModelFamily,
  MODEL_FAMILY_SERVICE,
  ModelFamily,
  OpenAIModelFamily,
} from "./shared/models";
import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats";
import { getUniqueIps } from "./proxy/rate-limit";
import { assertNever } from "./shared/utils";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
import { MistralAIKey } from "./shared/key-management/mistral-ai/provider";

const CACHE_TTL = 2000;

type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
  k.service === "openai";
const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
  k.service === "azure";
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
  k.service === "anthropic";
const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
  k.service === "google-ai";
const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
  k.service === "mistral-ai";
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";

/** Stats aggregated across all keys for a given service. */
type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs";
/** Stats aggregated across all keys for a given model family. */
type ModelAggregates = {
  active: number;
  trial?: number;
  revoked?: number;
  overQuota?: number;
  pozzed?: number;
  awsLogged?: number;
  queued: number;
  queueTime: string;
  tokens: number;
};
/** All possible combinations of model family and aggregate type. */
type ModelAggregateKey = `${ModelFamily}__${keyof ModelAggregates}`;

type AllStats = {
  proompts: number;
  tokens: number;
  tokenCost: number;
} & { [modelFamily in ModelFamily]?: ModelAggregates } & {
  [service in LLMService as `${service}__${ServiceAggregate}`]?: number;
};

type BaseFamilyInfo = {
  usage?: string;
  activeKeys: number;
  revokedKeys?: number;
  proomptersInQueue?: number;
  estimatedQueueTime?: string;
};
type OpenAIInfo = BaseFamilyInfo & {
  trialKeys?: number;
  overQuotaKeys?: number;
};
type AnthropicInfo = BaseFamilyInfo & { pozzedKeys?: number };
type AwsInfo = BaseFamilyInfo & { privacy?: string };

// prettier-ignore
export type ServiceInfo = {
  uptime: number;
  endpoints: {
    openai?: string;
    openai2?: string;
    "openai-image"?: string;
    anthropic?: string;
    "google-ai"?: string;
    "mistral-ai"?: string;
    aws?: string;
    azure?: string;
  };
  proompts?: number;
  tookens?: string;
  proomptersNow?: number;
  status?: string;
  config: ReturnType<typeof listConfig>;
  build: string;
} & { [f in OpenAIModelFamily]?: OpenAIInfo }
  & { [f in AnthropicModelFamily]?: AnthropicInfo; }
  & { [f in AwsBedrockModelFamily]?: AwsInfo }
  & { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
  & { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
  & { [f in MistralAIModelFamily]?: BaseFamilyInfo };

// https://stackoverflow.com/a/66661477
// type DeepKeyOf<T> = (
//   [T] extends [never]
//     ? ""
//     : T extends object
//     ? {
//         [K in Exclude<keyof T, symbol>]: `${K}${DotPrefix<DeepKeyOf<T[K]>>}`;
//       }[Exclude<keyof T, symbol>]
//     : ""
// ) extends infer D
//   ? Extract<D, string>
//   : never;
// type DotPrefix<T extends string> = T extends "" ? "" : `.${T}`;
// type ServiceInfoPath = `{${DeepKeyOf<ServiceInfo>}}`;

const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
  openai: {
    openai: `%BASE%/openai`,
    openai2: `%BASE%/openai/turbo-instruct`,
    "openai-image": `%BASE%/openai-image`,
  },
  anthropic: {
    anthropic: `%BASE%/anthropic`,
  },
  "google-ai": {
    "google-ai": `%BASE%/google-ai`,
  },
  "mistral-ai": {
    "mistral-ai": `%BASE%/mistral-ai`,
  },
  aws: {
    aws: `%BASE%/aws/claude`,
  },
  azure: {
    azure: `%BASE%/azure/openai`,
  },
};

const modelStats = new Map<ModelAggregateKey, number>();
const serviceStats = new Map<keyof AllStats, number>();

let cachedInfo: ServiceInfo | undefined;
let cacheTime = 0;

export function buildInfo(baseUrl: string, forAdmin = false): ServiceInfo {
  if (cacheTime + CACHE_TTL > Date.now()) return cachedInfo!;

  const keys = keyPool.list();
  const accessibleFamilies = new Set(
    keys
      .flatMap((k) => k.modelFamilies)
      .filter((f) => config.allowedModelFamilies.includes(f))
      .concat("turbo")
  );

  modelStats.clear();
  serviceStats.clear();
  keys.forEach(addKeyToAggregates);

  const endpoints = getEndpoints(baseUrl, accessibleFamilies);
  const trafficStats = getTrafficStats();
  const { serviceInfo, modelFamilyInfo } =
    getServiceModelStats(accessibleFamilies);
  const status = getStatus();

  if (config.staticServiceInfo && !forAdmin) {
    delete trafficStats.proompts;
    delete trafficStats.tookens;
    delete trafficStats.proomptersNow;
    for (const family of Object.keys(modelFamilyInfo)) {
      assertIsKnownModelFamily(family);
      delete modelFamilyInfo[family]?.proomptersInQueue;
      delete modelFamilyInfo[family]?.estimatedQueueTime;
      delete modelFamilyInfo[family]?.usage;
    }
  }

  return (cachedInfo = {
    uptime: Math.floor(process.uptime()),
    endpoints,
    ...trafficStats,
    ...serviceInfo,
    status,
    ...modelFamilyInfo,
    config: listConfig(),
    build: process.env.BUILD_INFO || "dev",
  });
}

function getStatus() {
  if (!config.checkKeys) return "Key checking is disabled.";

  let unchecked = 0;
  for (const service of LLM_SERVICES) {
    unchecked += serviceStats.get(`${service}__uncheckedKeys`) || 0;
  }

  return unchecked ? `Checking ${unchecked} keys...` : undefined;
}

function getEndpoints(baseUrl: string, accessibleFamilies: Set<ModelFamily>) {
  const endpoints: Record<string, string> = {};
  for (const service of LLM_SERVICES) {
    for (const [name, url] of Object.entries(SERVICE_ENDPOINTS[service])) {
      endpoints[name] = url.replace("%BASE%", baseUrl);
    }

    if (service === "openai" && !accessibleFamilies.has("dall-e")) {
      delete endpoints["openai-image"];
    }
  }
  return endpoints;
}

type TrafficStats = Pick<ServiceInfo, "proompts" | "tookens" | "proomptersNow">;

function getTrafficStats(): TrafficStats {
  const tokens = serviceStats.get("tokens") || 0;
  const tokenCost = serviceStats.get("tokenCost") || 0;
  return {
    proompts: serviceStats.get("proompts") || 0,
    tookens: `${prettyTokens(tokens)}${getCostSuffix(tokenCost)}`,
    ...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
  };
}

function getServiceModelStats(accessibleFamilies: Set<ModelFamily>) {
  const serviceInfo: {
    [s in LLMService as `${s}${"Keys" | "Orgs"}`]?: number;
  } = {};
  const modelFamilyInfo: { [f in ModelFamily]?: BaseFamilyInfo } = {};

  for (const service of LLM_SERVICES) {
    const hasKeys = serviceStats.get(`${service}__keys`) || 0;
    if (!hasKeys) continue;

    serviceInfo[`${service}Keys`] = hasKeys;
    accessibleFamilies.forEach((f) => {
      if (MODEL_FAMILY_SERVICE[f] === service) {
        modelFamilyInfo[f] = getInfoForFamily(f);
      }
    });

    if (service === "openai" && config.checkKeys) {
      serviceInfo.openaiOrgs = getUniqueOpenAIOrgs(keyPool.list());
    }
  }
  return { serviceInfo, modelFamilyInfo };
}

function getUniqueOpenAIOrgs(keys: KeyPoolKey[]) {
  const orgIds = new Set(
    keys.filter((k) => k.service === "openai").map((k: any) => k.organizationId)
  );
  return orgIds.size;
}

function increment<T extends keyof AllStats | ModelAggregateKey>(
  map: Map<T, number>,
  key: T,
  delta = 1
) {
  map.set(key, (map.get(key) || 0) + delta);
}

function addKeyToAggregates(k: KeyPoolKey) {
  increment(serviceStats, "proompts", k.promptCount);
  increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
  increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
  increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
  increment(serviceStats, "mistral-ai__keys", k.service === "mistral-ai" ? 1 : 0);
  increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
  increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);

  let sumTokens = 0;
  let sumCost = 0;

  switch (k.service) {
    case "openai":
      if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
      increment(
        serviceStats,
        "openai__uncheckedKeys",
        Boolean(k.lastChecked) ? 0 : 1
      );

      k.modelFamilies.forEach((f) => {
        const tokens = k[`${f}Tokens`];
        sumTokens += tokens;
        sumCost += getTokenCostUsd(f, tokens);
        increment(modelStats, `${f}__tokens`, tokens);
        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
        increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
        increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
      });
      break;
    case "azure":
      if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
      k.modelFamilies.forEach((f) => {
        const tokens = k[`${f}Tokens`];
        sumTokens += tokens;
        sumCost += getTokenCostUsd(f, tokens);
        increment(modelStats, `${f}__tokens`, tokens);
        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
      });
      break;
    case "anthropic": {
      if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
      const family = "claude";
      sumTokens += k.claudeTokens;
      sumCost += getTokenCostUsd(family, k.claudeTokens);
      increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
      increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
      increment(modelStats, `${family}__tokens`, k.claudeTokens);
      increment(modelStats, `${family}__pozzed`, k.isPozzed ? 1 : 0);
      increment(
        serviceStats,
        "anthropic__uncheckedKeys",
        Boolean(k.lastChecked) ? 0 : 1
      );
      break;
    }
    case "google-ai": {
      if (!keyIsGoogleAIKey(k)) throw new Error("Invalid key type");
      const family = "gemini-pro";
      sumTokens += k["gemini-proTokens"];
      sumCost += getTokenCostUsd(family, k["gemini-proTokens"]);
      increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
      increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
      increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]);
      break;
    }
    case "mistral-ai": {
      if (!keyIsMistralAIKey(k)) throw new Error("Invalid key type");
      k.modelFamilies.forEach((f) => {
        const tokens = k[`${f}Tokens`];
        sumTokens += tokens;
        sumCost += getTokenCostUsd(f, tokens);
        increment(modelStats, `${f}__tokens`, tokens);
        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
      });
      break;
    }
    case "aws": {
      if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
      const family = "aws-claude";
      sumTokens += k["aws-claudeTokens"];
      sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]);
      increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
      increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
      increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]);

      // Ignore revoked keys for aws logging stats, but include keys where the
      // logging status is unknown.
      const countAsLogged =
        k.lastChecked && !k.isDisabled && k.awsLoggingStatus !== "disabled";
      increment(modelStats, `${family}__awsLogged`, countAsLogged ? 1 : 0);

      break;
    }
    default:
      assertNever(k.service);
  }

  increment(serviceStats, "tokens", sumTokens);
  increment(serviceStats, "tokenCost", sumCost);
}

function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
  const tokens = modelStats.get(`${family}__tokens`) || 0;
  const cost = getTokenCostUsd(family, tokens);
  let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo = {
    usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
    activeKeys: modelStats.get(`${family}__active`) || 0,
    revokedKeys: modelStats.get(`${family}__revoked`) || 0,
  };

  // Add service-specific stats to the info object.
  if (config.checkKeys) {
    const service = MODEL_FAMILY_SERVICE[family];
    switch (service) {
      case "openai":
        info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0;
        info.trialKeys = modelStats.get(`${family}__trial`) || 0;

        // Delete trial/revoked keys for non-turbo families.
        // Trials are turbo 99% of the time, and if a key is invalid we don't
        // know what models it might have had assigned to it.
        if (family !== "turbo") {
          delete info.trialKeys;
          delete info.revokedKeys;
        }
        break;
      case "anthropic":
        info.pozzedKeys = modelStats.get(`${family}__pozzed`) || 0;
        break;
      case "aws":
        const logged = modelStats.get(`${family}__awsLogged`) || 0;
        if (logged > 0) {
          info.privacy = config.allowAwsLogging
            ? `${logged} active keys are potentially logged.`
            : `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`;
        }
        break;
    }
  }

  // Add queue stats to the info object.
  const queue = getQueueInformation(family);
  info.proomptersInQueue = queue.proomptersInQueue;
  info.estimatedQueueTime = queue.estimatedQueueTime;

  return info;
}

/** Returns queue time in seconds, or minutes + seconds if over 60 seconds. */
function getQueueInformation(partition: ModelFamily) {
  const waitMs = getEstimatedWaitTime(partition);
  const waitTime =
    waitMs < 60000
      ? `${Math.round(waitMs / 1000)}sec`
      : `${Math.round(waitMs / 60000)}min, ${Math.round(
          (waitMs % 60000) / 1000
        )}sec`;
  return {
    proomptersInQueue: getQueueLength(partition),
    estimatedQueueTime: waitMs > 2000 ? waitTime : "no wait",
  };
}
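The `${family}__${stat}` keys used throughout this file flatten a two-dimensional table (model family × statistic) into a single `Map` while keeping the key space type-checked. A reduced sketch of the pattern, with hypothetical names:

```ts
type Family = "turbo" | "claude";
type Stat = "tokens" | "active";
type AggKey = `${Family}__${Stat}`; // template literal type: 4 valid keys

const agg = new Map<AggKey, number>();
const bump = (key: AggKey, delta = 1) =>
  agg.set(key, (agg.get(key) || 0) + delta);

bump("turbo__tokens", 1500);
bump("turbo__active");
// bump("turbo__typo") would be a compile error.
console.log(agg.get("turbo__tokens")); // 1500
```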
+5
-3
@@ -1,8 +1,10 @@
// noinspection JSUnusedGlobalSymbols,ES6UnusedImports

import type { HttpRequest } from "@smithy/types";
import { Express } from "express-serve-static-core";
import { APIFormat, Key, LLMService } from "../shared/key-management";
import { User } from "../shared/users/schema";
import { ModelFamily } from "../shared/models";
import { APIFormat, Key } from "./key-management";
import { User } from "./users/schema";
import { LLMService, ModelFamily } from "./models";

declare global {
  namespace Express {
@@ -48,20 +48,20 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
  protected handleAxiosError(key: AnthropicKey, error: AxiosError) {
    if (error.response && AnthropicKeyChecker.errorIsAnthropicAPIError(error)) {
      const { status, data } = error.response;
      if (status === 401) {
      if (status === 401 || status === 403) {
        this.log.warn(
          { key: key.hash, error: data },
          "Key is invalid or revoked. Disabling key."
        );
        this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
      } else if (status === 429) {
      }
      else if (status === 429) {
        switch (data.error.type) {
          case "rate_limit_error":
            this.log.warn(
              { key: key.hash, error: error.message },
              "Key is rate limited. Rechecking in 10 seconds."
            );
            const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
            this.updateKey(key.hash, { lastChecked: next });
            break;

@@ -36,34 +36,10 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {

  protected async testKeyOrFail(key: AzureOpenAIKey) {
    const model = await this.testModel(key);
    this.log.info(
      { key: key.hash, deploymentModel: model },
      "Checked key."
    );
    this.log.info({ key: key.hash, deploymentModel: model }, "Checked key.");
    this.updateKey(key.hash, { modelFamilies: [model] });
  }

  // provided api-key header isn't valid (401)
  // {
  //   "error": {
  //     "code": "401",
  //     "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
  //   }
  // }

  // api key correct but deployment id is wrong (404)
  // {
  //   "error": {
  //     "code": "DeploymentNotFound",
  //     "message": "The API deployment for this resource does not exist. If you created the deployment within the last 5 minutes, please wait a moment and try again."
  //   }
  // }

  // resource name is wrong (node will throw ENOTFOUND)

  // rate limited (429)
  // TODO: try to reproduce this

  protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) {
    if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) {
      const data = error.response.data;
@@ -88,6 +64,20 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
            isDisabled: true,
            isRevoked: true,
          });
        case "429":
          this.log.warn(
            { key: key.hash, errorType, error: error.response.data },
            "Key is rate limited. Rechecking key in 1 minute."
          );
          this.updateKey(key.hash, { lastChecked: Date.now() });
          setTimeout(async () => {
            this.log.info(
              { key: key.hash },
              "Rechecking Azure key after rate limit."
            );
            await this.checkKey(key);
          }, 1000 * 60);
          return;
        default:
          this.log.error(
            { key: key.hash, errorType, error: error.response.data, status },
@@ -129,7 +119,32 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
      headers: { "Content-Type": "application/json", "api-key": apiKey },
    });

    return getAzureOpenAIModelFamily(data.model);
    const family = getAzureOpenAIModelFamily(data.model);

    // Azure returns "gpt-4" even for GPT-4 Turbo, so we need further checks.
    // Otherwise we can use the model family Azure returned.
    if (family !== "azure-gpt4") {
      return family;
    }

    // Try to send an oversized prompt. GPT-4 Turbo can handle this but regular
    // GPT-4 will return a Bad Request error.
    const contextText = {
      max_tokens: 9000,
      stream: false,
      temperature: 0,
      seed: 0,
      messages: [{ role: "user", content: "" }],
    };
    const { data: contextTest, status } = await axios.post(url, contextText, {
      headers: { "Content-Type": "application/json", "api-key": apiKey },
      validateStatus: (status) => status === 400 || status === 200,
    });
    const code = contextTest.error?.code;
    this.log.debug({ code, status }, "Performed Azure GPT4 context size test.");

    if (code === "context_length_exceeded") return "azure-gpt4";
    return "azure-gpt4-turbo";
  }

  static errorIsAzureError(error: AxiosError): error is AxiosError<AzureError> {

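For reference, the probe works because a standard GPT-4 deployment tops out at 8,192 context tokens, so demanding 9,000 completion tokens is guaranteed to trip `context_length_exceeded`, while GPT-4 Turbo's 128K window accepts the same request; only the status and error code matter, which is why the test message can be empty.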
@@ -6,7 +6,6 @@ import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models";
import { OpenAIModel } from "../openai/provider";
import { AzureOpenAIKeyChecker } from "./checker";
import { AwsKeyChecker } from "../aws/checker";

export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;

+28
-24
@@ -2,13 +2,17 @@ import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models";
import type { GoogleAIModelFamily } from "../../models";

// https://developers.generativeai.google.com/models/language
export type GooglePalmModel = "text-bison-001";
// Note that Google AI is not the same as Vertex AI, both are provided by Google
// but Vertex is the GCP product for enterprise, while Google AI is the
// consumer-ish product. The API is different, and keys are not compatible.
// https://ai.google.dev/docs/migrate_to_cloud

export type GooglePalmKeyUpdate = Omit<
  Partial<GooglePalmKey>,
export type GoogleAIModel = "gemini-pro";

export type GoogleAIKeyUpdate = Omit<
  Partial<GoogleAIKey>,
  | "key"
  | "hash"
  | "lastUsed"
@@ -17,13 +21,13 @@ export type GooglePalmKeyUpdate = Omit<
  | "rateLimitedUntil"
>;

type GooglePalmKeyUsage = {
  [K in GooglePalmModelFamily as `${K}Tokens`]: number;
type GoogleAIKeyUsage = {
  [K in GoogleAIModelFamily as `${K}Tokens`]: number;
};

export interface GooglePalmKey extends Key, GooglePalmKeyUsage {
  readonly service: "google-palm";
  readonly modelFamilies: GooglePalmModelFamily[];
export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
  readonly service: "google-ai";
  readonly modelFamilies: GoogleAIModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
@@ -42,27 +46,27 @@ const RATE_LIMIT_LOCKOUT = 2000;
 */
const KEY_REUSE_DELAY = 500;

export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
  readonly service = "google-palm";
export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
  readonly service = "google-ai";

  private keys: GooglePalmKey[] = [];
  private keys: GoogleAIKey[] = [];
  private log = logger.child({ module: "key-provider", service: this.service });

  constructor() {
    const keyConfig = config.googlePalmKey?.trim();
    const keyConfig = config.googleAIKey?.trim();
    if (!keyConfig) {
      this.log.warn(
        "GOOGLE_PALM_KEY is not set. PaLM API will not be available."
        "GOOGLE_AI_KEY is not set. Google AI API will not be available."
      );
      return;
    }
    let bareKeys: string[];
    bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
    for (const key of bareKeys) {
      const newKey: GooglePalmKey = {
      const newKey: GoogleAIKey = {
        key,
        service: this.service,
        modelFamilies: ["bison"],
        modelFamilies: ["gemini-pro"],
        isDisabled: false,
        isRevoked: false,
        promptCount: 0,
@@ -75,11 +79,11 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
        bisonTokens: 0,
        "gemini-proTokens": 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded PaLM keys.");
    this.log.info({ keyCount: this.keys.length }, "Loaded Google AI keys.");
  }

  public init() {}
@@ -88,10 +92,10 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  public get(_model: GooglePalmModel) {
  public get(_model: GoogleAIModel) {
    const availableKeys = this.keys.filter((k) => !k.isDisabled);
    if (availableKeys.length === 0) {
      throw new Error("No Google PaLM keys available");
      throw new Error("No Google AI keys available");
    }

    // (largely copied from the OpenAI provider, without trial key support)
@@ -122,14 +126,14 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
    return { ...selectedKey };
  }

  public disable(key: GooglePalmKey) {
  public disable(key: GoogleAIKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  public update(hash: string, update: Partial<GooglePalmKey>) {
  public update(hash: string, update: Partial<GoogleAIKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }
@@ -142,7 +146,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
    key.promptCount++;
    key.bisonTokens += tokens;
    key["gemini-proTokens"] += tokens;
  }

  public getLockoutPeriod() {
@@ -1,29 +1,23 @@
import type { LLMService, ModelFamily } from "../models";
import { OpenAIModel } from "./openai/provider";
import { AnthropicModel } from "./anthropic/provider";
import { GooglePalmModel } from "./palm/provider";
import { GoogleAIModel } from "./google-ai/provider";
import { AwsBedrockModel } from "./aws/provider";
import { AzureOpenAIModel } from "./azure/provider";
import { KeyPool } from "./key-pool";
import type { ModelFamily } from "../models";

/** The request and response format used by a model's API. */
export type APIFormat =
  | "openai"
  | "anthropic"
  | "google-palm"
  | "google-ai"
  | "mistral-ai"
  | "openai-text"
  | "openai-image";
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
export type LLMService =
  | "openai"
  | "anthropic"
  | "google-palm"
  | "aws"
  | "azure";
export type Model =
  | OpenAIModel
  | AnthropicModel
  | GooglePalmModel
  | GoogleAIModel
  | AwsBedrockModel
  | AzureOpenAIModel;

@@ -77,6 +71,6 @@ export interface KeyProvider<T extends Key = Key> {
export const keyPool = new KeyPool();
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GooglePalmKey } from "./palm/provider";
export { GoogleAIKey } from "./google-ai/provider";
export { AwsBedrockKey } from "./aws/provider";
export { AzureOpenAIKey } from "./azure/provider";

@@ -4,14 +4,14 @@ import os from "os";
import schedule from "node-schedule";
import { config } from "../../config";
import { logger } from "../../logger";
import { Key, Model, KeyProvider, LLMService } from "./index";
import { LLMService, MODEL_FAMILY_SERVICE, ModelFamily } from "../models";
import { Key, Model, KeyProvider } from "./index";
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GooglePalmKeyProvider } from "./palm/provider";
import { GoogleAIKeyProvider } from "./google-ai/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { ModelFamily } from "../models";
import { assertNever } from "../utils";
import { AzureOpenAIKeyProvider } from "./azure/provider";
import { MistralAIKeyProvider } from "./mistral-ai/provider";

type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;

@@ -24,7 +24,8 @@ export class KeyPool {
  constructor() {
    this.keyProviders.push(new OpenAIKeyProvider());
    this.keyProviders.push(new AnthropicKeyProvider());
    this.keyProviders.push(new GooglePalmKeyProvider());
    this.keyProviders.push(new GoogleAIKeyProvider());
    this.keyProviders.push(new MistralAIKeyProvider());
    this.keyProviders.push(new AwsBedrockKeyProvider());
    this.keyProviders.push(new AzureOpenAIKeyProvider());
  }
@@ -82,7 +83,7 @@ export class KeyPool {
  }

  public getLockoutPeriod(family: ModelFamily): number {
    const service = this.getServiceForModelFamily(family);
    const service = MODEL_FAMILY_SERVICE[family];
    return this.getKeyProvider(service).getLockoutPeriod(family);
  }

@@ -119,9 +120,12 @@ export class KeyPool {
    } else if (model.startsWith("claude-")) {
      // https://console.anthropic.com/docs/api/reference#parameters
      return "anthropic";
    } else if (model.includes("bison")) {
    } else if (model.includes("gemini")) {
      // https://developers.generativeai.google.com/models/language
      return "google-palm";
      return "google-ai";
    } else if (model.includes("mistral")) {
      // https://docs.mistral.ai/platform/endpoints
      return "mistral-ai";
    } else if (model.startsWith("anthropic.claude")) {
      // AWS offers models from a few providers
      // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
@@ -132,30 +136,6 @@ export class KeyPool {
    throw new Error(`Unknown service for model '${model}'`);
  }

  private getServiceForModelFamily(modelFamily: ModelFamily): LLMService {
    switch (modelFamily) {
      case "gpt4":
      case "gpt4-32k":
      case "gpt4-turbo":
      case "turbo":
      case "dall-e":
        return "openai";
      case "claude":
        return "anthropic";
      case "bison":
        return "google-palm";
      case "aws-claude":
        return "aws";
      case "azure-turbo":
      case "azure-gpt4":
      case "azure-gpt4-32k":
      case "azure-gpt4-turbo":
        return "azure";
      default:
        assertNever(modelFamily);
    }
  }

  private getKeyProvider(service: LLMService): KeyProvider {
    return this.keyProviders.find((provider) => provider.service === service)!;
  }
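Replacing the removed switch with the `MODEL_FAMILY_SERVICE` lookup keeps the family-to-service mapping in one place and lets the mapped type enforce exhaustiveness instead of `assertNever`. A reduced sketch of the equivalence, with abbreviated unions:

```ts
type Family = "turbo" | "claude" | "gemini-pro";
type Service = "openai" | "anthropic" | "google-ai";

// `{ [f in Family]: Service }` fails to compile if any family is missing —
// the same guarantee the old switch needed a `default: assertNever` for.
const FAMILY_SERVICE: { [f in Family]: Service } = {
  turbo: "openai",
  claude: "anthropic",
  "gemini-pro": "google-ai",
};

const service = FAMILY_SERVICE["gemini-pro"]; // "google-ai"
```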
@@ -0,0 +1,112 @@
import axios, { AxiosError } from "axios";
import type { MistralAIModelFamily, OpenAIModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { MistralAIKey, MistralAIKeyProvider } from "./provider";
import { getMistralAIModelFamily, getOpenAIModelFamily } from "../../models";

const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
const GET_MODELS_URL = "https://api.mistral.ai/v1/models";

type GetModelsResponse = {
  data: [{ id: string }];
};

type MistralAIError = {
  message: string;
  request_id: string;
};

type UpdateFn = typeof MistralAIKeyProvider.prototype.update;

export class MistralAIKeyChecker extends KeyCheckerBase<MistralAIKey> {
  constructor(keys: MistralAIKey[], updateKey: UpdateFn) {
    super(keys, {
      service: "mistral-ai",
      keyCheckPeriod: KEY_CHECK_PERIOD,
      minCheckInterval: MIN_CHECK_INTERVAL,
      recurringChecksEnabled: false,
      updateKey,
    });
  }

  protected async testKeyOrFail(key: MistralAIKey) {
    // We only need to check for provisioned models on the initial check.
    const isInitialCheck = !key.lastChecked;
    if (isInitialCheck) {
      const provisionedModels = await this.getProvisionedModels(key);
      const updates = {
        modelFamilies: provisionedModels,
      };
      this.updateKey(key.hash, updates);
    }
    this.log.info({ key: key.hash, models: key.modelFamilies }, "Checked key.");
  }

  private async getProvisionedModels(
    key: MistralAIKey
  ): Promise<MistralAIModelFamily[]> {
    const opts = { headers: MistralAIKeyChecker.getHeaders(key) };
    const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
    const models = data.data;

    const families = new Set<MistralAIModelFamily>();
    models.forEach(({ id }) => families.add(getMistralAIModelFamily(id)));

    // We want to update the key's model families here, but we don't want to
    // update its `lastChecked` timestamp because we need to let the liveness
    // check run before we can consider the key checked.

    const familiesArray = [...families];
    const keyFromPool = this.keys.find((k) => k.hash === key.hash)!;
    this.updateKey(key.hash, {
      modelFamilies: familiesArray,
      lastChecked: keyFromPool.lastChecked,
    });
    return familiesArray;
  }

  protected handleAxiosError(key: MistralAIKey, error: AxiosError) {
    if (error.response && MistralAIKeyChecker.errorIsMistralAIError(error)) {
      const { status, data } = error.response;
      if (status === 401) {
        this.log.warn(
          { key: key.hash, error: data },
          "Key is invalid or revoked. Disabling key."
        );
        this.updateKey(key.hash, {
          isDisabled: true,
          isRevoked: true,
          modelFamilies: ["mistral-tiny"],
        });
      } else {
        this.log.error(
          { key: key.hash, status, error: data },
          "Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
        );
        this.updateKey(key.hash, { lastChecked: Date.now() });
      }
      return;
    }
    this.log.error(
      { key: key.hash, error: error.message },
      "Network error while checking key; trying this key again in a minute."
    );
    const oneMinute = 60 * 1000;
    const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
    this.updateKey(key.hash, { lastChecked: next });
  }

  static errorIsMistralAIError(
    error: AxiosError
  ): error is AxiosError<MistralAIError> {
    const data = error.response?.data as any;
    return data?.message && data?.request_id;
  }

  static getHeaders(key: MistralAIKey) {
    return {
      Authorization: `Bearer ${key.key}`,
    };
  }
}
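The provisioning check reduces to a single GET against Mistral's model list; a standalone sketch, assuming a valid key in the MISTRAL_AI_KEY environment variable:

```ts
import axios from "axios";

async function listMistralModels() {
  const { data } = await axios.get("https://api.mistral.ai/v1/models", {
    headers: { Authorization: `Bearer ${process.env.MISTRAL_AI_KEY}` },
  });
  // e.g. ["mistral-tiny", "mistral-small", "mistral-medium", ...]
  return data.data.map((m: { id: string }) => m.id);
}
```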
@@ -0,0 +1,210 @@
import crypto from "crypto";
import { Key, KeyProvider, Model } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { MistralAIKeyChecker } from "./checker";

export type MistralAIModel =
  | "mistral-tiny"
  | "mistral-small"
  | "mistral-medium";

export type MistralAIKeyUpdate = Omit<
  Partial<MistralAIKey>,
  | "key"
  | "hash"
  | "lastUsed"
  | "promptCount"
  | "rateLimitedAt"
  | "rateLimitedUntil"
>;

type MistralAIKeyUsage = {
  [K in MistralAIModelFamily as `${K}Tokens`]: number;
};

export interface MistralAIKey extends Key, MistralAIKeyUsage {
  readonly service: "mistral-ai";
  readonly modelFamilies: MistralAIModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
}

/**
 * Upon being rate limited, a key will be locked out for this many milliseconds
 * while we wait for other concurrent requests to finish.
 */
const RATE_LIMIT_LOCKOUT = 2000;
/**
 * Upon assigning a key, we will wait this many milliseconds before allowing it
 * to be used again. This is to prevent the queue from flooding a key with too
 * many requests while we wait to learn whether previous ones succeeded.
 */
const KEY_REUSE_DELAY = 500;

export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
  readonly service = "mistral-ai";

  private keys: MistralAIKey[] = [];
  private checker?: MistralAIKeyChecker;
  private log = logger.child({ module: "key-provider", service: this.service });

  constructor() {
    const keyConfig = config.mistralAIKey?.trim();
    if (!keyConfig) {
      this.log.warn(
        "MISTRAL_AI_KEY is not set. Mistral AI API will not be available."
      );
      return;
    }
    let bareKeys: string[];
    bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
    for (const key of bareKeys) {
      const newKey: MistralAIKey = {
        key,
        service: this.service,
        modelFamilies: ["mistral-tiny", "mistral-small", "mistral-medium"],
        isDisabled: false,
        isRevoked: false,
        promptCount: 0,
        lastUsed: 0,
        rateLimitedAt: 0,
        rateLimitedUntil: 0,
        hash: `mst-${crypto
          .createHash("sha256")
          .update(key)
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
        "mistral-tinyTokens": 0,
        "mistral-smallTokens": 0,
        "mistral-mediumTokens": 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded Mistral AI keys.");
  }

  public init() {
    if (config.checkKeys) {
      const updateFn = this.update.bind(this);
      this.checker = new MistralAIKeyChecker(this.keys, updateFn);
      this.checker.start();
    }
  }

  public list() {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  public get(_model: Model) {
    const availableKeys = this.keys.filter((k) => !k.isDisabled);
    if (availableKeys.length === 0) {
      throw new Error("No Mistral AI keys available");
    }

    // (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //    a. If all keys were rate limited recently, select the least-recently
    //       rate limited key.
    // 2. Keys which have not been used in the longest time

    const now = Date.now();

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }

  public disable(key: MistralAIKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  public update(hash: string, update: Partial<MistralAIKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }

  public available() {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  public incrementUsage(hash: string, model: string, tokens: number) {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
    key.promptCount++;
    const family = getMistralAIModelFamily(model);
    key[`${family}Tokens`] += tokens;
  }

  public getLockoutPeriod() {
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return the time until the first key is
    // ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }

  /**
   * This is called when we receive a 429, which means there are already five
   * concurrent requests running on this key. We don't have any information on
   * when these requests will resolve, so all we can do is wait a bit and try
   * again. We will lock the key for 2 seconds after getting a 429 before
   * retrying in order to give the other requests a chance to finish.
   */
  public markRateLimited(keyHash: string) {
    this.log.debug({ key: keyHash }, "Key rate limited");
    const key = this.keys.find((k) => k.hash === keyHash)!;
    const now = Date.now();
    key.rateLimitedAt = now;
    key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
  }

  public recheck() {}

  /**
   * Applies a short artificial delay to the key upon dequeueing, in order to
   * prevent it from being immediately assigned to another request before the
   * current one can be dispatched.
   **/
  private throttle(hash: string) {
    const now = Date.now();
    const key = this.keys.find((k) => k.hash === hash)!;

    const currentRateLimit = key.rateLimitedUntil;
    const nextRateLimit = now + KEY_REUSE_DELAY;

    key.rateLimitedAt = now;
    key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
  }
}
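A worked example of `getLockoutPeriod`'s contract, under assumed timestamps:

```ts
// Suppose two active keys at now = 10_000:
//   key A: rateLimitedUntil = 11_500  (locked for another 1,500ms)
//   key B: rateLimitedUntil =  9_000  (lock already expired)
// At least one key is usable, so the queue must not stall:
//   getLockoutPeriod() === 0
// If B were instead locked until 12_000, every key would be limited and the
// queue should sleep until the soonest unlock:
//   getLockoutPeriod() === Math.min(11_500 - 10_000, 12_000 - 10_000) === 1_500
```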
+79
-12
@@ -1,8 +1,20 @@
// Don't import anything here, this is imported by config.ts
// Don't import any other project files here as this is one of the first modules
// loaded and it will cause circular imports.

import pino from "pino";
import type { Request } from "express";
import { assertNever } from "./utils";

/**
 * The service that a model is hosted on. Distinct from `APIFormat` because some
 * services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
 */
export type LLMService =
  | "openai"
  | "anthropic"
  | "google-ai"
  | "mistral-ai"
  | "aws"
  | "azure";

export type OpenAIModelFamily =
  | "turbo"
@@ -11,7 +23,11 @@ export type OpenAIModelFamily =
  | "gpt4-turbo"
  | "dall-e";
export type AnthropicModelFamily = "claude";
export type GooglePalmModelFamily = "bison";
export type GoogleAIModelFamily = "gemini-pro";
export type MistralAIModelFamily =
  | "mistral-tiny"
  | "mistral-small"
  | "mistral-medium";
export type AwsBedrockModelFamily = "aws-claude";
export type AzureOpenAIModelFamily = `azure-${Exclude<
  OpenAIModelFamily,
@@ -20,7 +36,8 @@ export type AzureOpenAIModelFamily = `azure-${Exclude<
export type ModelFamily =
  | OpenAIModelFamily
  | AnthropicModelFamily
  | GooglePalmModelFamily
  | GoogleAIModelFamily
  | MistralAIModelFamily
  | AwsBedrockModelFamily
  | AzureOpenAIModelFamily;

@@ -33,7 +50,10 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
  "gpt4-turbo",
  "dall-e",
  "claude",
  "bison",
  "gemini-pro",
  "mistral-tiny",
  "mistral-small",
  "mistral-medium",
  "aws-claude",
  "azure-turbo",
  "azure-gpt4",
@@ -41,6 +61,17 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
  "azure-gpt4-turbo",
] as const);

export const LLM_SERVICES = (<A extends readonly LLMService[]>(
  arr: A & ([LLMService] extends [A[number]] ? unknown : never)
) => arr)([
  "openai",
  "anthropic",
  "google-ai",
  "mistral-ai",
  "aws",
  "azure",
] as const);

export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
  "^gpt-4-1106(-preview)?$": "gpt4-turbo",
  "^gpt-4(-\\d{4})?-vision(-preview)?$": "gpt4-turbo",
@@ -53,7 +84,27 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
  "^dall-e-\\d{1}$": "dall-e",
};

const modelLogger = pino({ level: "debug" }).child({ module: "startup" });
export const MODEL_FAMILY_SERVICE: {
  [f in ModelFamily]: LLMService;
} = {
  turbo: "openai",
  gpt4: "openai",
  "gpt4-turbo": "openai",
  "gpt4-32k": "openai",
  "dall-e": "openai",
  claude: "anthropic",
  "aws-claude": "aws",
  "azure-turbo": "azure",
  "azure-gpt4": "azure",
  "azure-gpt4-32k": "azure",
  "azure-gpt4-turbo": "azure",
  "gemini-pro": "google-ai",
  "mistral-tiny": "mistral-ai",
  "mistral-small": "mistral-ai",
  "mistral-medium": "mistral-ai",
};

pino({ level: "debug" }).child({ module: "startup" });

export function getOpenAIModelFamily(
  model: string,
@@ -70,10 +121,19 @@ export function getClaudeModelFamily(model: string): ModelFamily {
  return "claude";
}

export function getGooglePalmModelFamily(model: string): ModelFamily {
  if (model.match(/^\w+-bison-\d{3}$/)) return "bison";
  modelLogger.warn({ model }, "Could not determine Google PaLM model family");
  return "bison";
export function getGoogleAIModelFamily(_model: string): ModelFamily {
  return "gemini-pro";
}

export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
  switch (model) {
    case "mistral-tiny":
    case "mistral-small":
    case "mistral-medium":
      return model;
    default:
      return "mistral-tiny";
  }
}

export function getAwsBedrockModelFamily(_model: string): ModelFamily {
@@ -130,8 +190,11 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
    case "openai-image":
      modelFamily = getOpenAIModelFamily(model);
      break;
    case "google-palm":
      modelFamily = getGooglePalmModelFamily(model);
    case "google-ai":
      modelFamily = getGoogleAIModelFamily(model);
      break;
    case "mistral-ai":
      modelFamily = getMistralAIModelFamily(model);
      break;
    default:
      assertNever(req.outboundApi);
@@ -140,3 +203,7 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {

  return (req.modelFamily = modelFamily);
}

function assertNever(x: never): never {
  throw new Error(`Called assertNever with argument ${x}.`);
}
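`OPENAI_MODEL_FAMILY_MAP` above is keyed by regex source strings. A minimal sketch of the kind of lookup `getOpenAIModelFamily` performs follows; the real function's signature is truncated in the hunk above, so the fallback shown here is an assumption:

```
// Sketch only: iterate the regex-keyed map and return the first family
// whose pattern matches the incoming model name.
function lookupOpenAIModelFamily(model: string): OpenAIModelFamily {
  for (const [pattern, family] of Object.entries(OPENAI_MODEL_FAMILY_MAP)) {
    if (new RegExp(pattern).test(model)) return family;
  }
  return "turbo"; // assumed fallback for unrecognized models
}

lookupOpenAIModelFamily("gpt-4-1106-preview"); // "gpt4-turbo"
lookupOpenAIModelFamily("dall-e-3"); // "dall-e"
```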
@@ -1,3 +1,4 @@
import { config } from "../config";
import { ModelFamily } from "./models";

// technically slightly underestimates, because completion tokens cost more
@@ -24,6 +25,15 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
    case "claude":
      cost = 0.00001102;
      break;
    case "mistral-tiny":
      cost = 0.00000031;
      break;
    case "mistral-small":
      cost = 0.00000132;
      break;
    case "mistral-medium":
      cost = 0.0000055;
      break;
  }
  return cost * Math.max(0, tokens);
}
@@ -40,3 +50,8 @@ export function prettyTokens(tokens: number): string {
    return (tokens / 1000000000).toFixed(3) + "b";
  }
}

export function getCostSuffix(cost: number) {
  if (!config.showTokenCosts) return "";
  return ` ($${cost.toFixed(2)})`;
}
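The per-token rates above make the arithmetic easy to check: at mistral-medium's 0.0000055 USD per token, one million tokens comes to $5.50.

```
// Worked examples using the rates from getTokenCostUsd above.
getTokenCostUsd("mistral-medium", 1_000_000); // 0.0000055 * 1,000,000 = 5.5 (USD)
getTokenCostUsd("mistral-tiny", 1_000_000);   // 0.00000031 * 1,000,000 = 0.31
getTokenCostUsd("claude", -500);              // Math.max(0, -500) clamps to $0
```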
+55
-24
@@ -1,6 +1,7 @@
import { Request, Response } from "express";
import { Response } from "express";
import { IncomingMessage } from "http";
import { assertNever } from "./utils";
import { APIFormat } from "./key-management";

export function initializeSseStream(res: Response) {
  res.statusCode = 200;
@@ -39,54 +40,84 @@ export function copySseResponseHeaders(
 * that the request is being proxied to. Used to send error messages to the
 * client in the middle of a streaming request.
 */
export function buildFakeSse(type: string, string: string, req: Request) {
  let fakeEvent;
  const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`;
export function makeCompletionSSE({
  format,
  title,
  message,
  obj,
  reqId,
  model = "unknown",
}: {
  format: APIFormat;
  title: string;
  message: string;
  obj?: object;
  reqId: string | number | object;
  model?: string;
}) {
  const id = String(reqId);
  const content = `\n\n**${title}**\n${message}${
    obj ? `\n\`\`\`\n${JSON.stringify(obj, null, 2)}\n\`\`\`\n` : ""
  }`;

  switch (req.inboundApi) {
  let event;

  switch (format) {
    case "openai":
      fakeEvent = {
        id: "chatcmpl-" + req.id,
    case "mistral-ai":
      event = {
        id: "chatcmpl-" + id,
        object: "chat.completion.chunk",
        created: Date.now(),
        model: req.body?.model,
        choices: [{ delta: { content }, index: 0, finish_reason: type }],
        model,
        choices: [{ delta: { content }, index: 0, finish_reason: title }],
      };
      break;
    case "openai-text":
      fakeEvent = {
        id: "cmpl-" + req.id,
      event = {
        id: "cmpl-" + id,
        object: "text_completion",
        created: Date.now(),
        choices: [
          { text: content, index: 0, logprobs: null, finish_reason: type },
          { text: content, index: 0, logprobs: null, finish_reason: title },
        ],
        model: req.body?.model,
        model,
      };
      break;
    case "anthropic":
      fakeEvent = {
      event = {
        completion: content,
        stop_reason: type,
        truncated: false, // I've never seen this be true
        stop_reason: title,
        truncated: false,
        stop: null,
        model: req.body?.model,
        log_id: "proxy-req-" + req.id,
        model,
        log_id: "proxy-req-" + id,
      };
      break;
    case "google-palm":
    case "google-ai":
      return JSON.stringify({
        candidates: [
          {
            content: { parts: [{ text: content }], role: "model" },
            finishReason: title,
            index: 0,
            tokenCount: null,
            safetyRatings: [],
          },
        ],
      });
    case "openai-image":
      throw new Error(`SSE not supported for ${req.inboundApi} requests`);
      throw new Error(`SSE not supported for ${format} requests`);
    default:
      assertNever(req.inboundApi);
      assertNever(format);
  }

  if (req.inboundApi === "anthropic") {
  if (format === "anthropic") {
    return (
      ["event: completion", `data: ${JSON.stringify(fakeEvent)}`].join("\n") +
      ["event: completion", `data: ${JSON.stringify(event)}`].join("\n") +
      "\n\n"
    );
  }

  return `data: ${JSON.stringify(fakeEvent)}\n\n`;
  return `data: ${JSON.stringify(event)}\n\n`;
}
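Per the doc comment, `makeCompletionSSE` is used to surface an error to the client mid-stream. A hedged usage sketch; the payload shapes are exactly those built above, but the sample values are invented:

```
// Illustrative call only; the argument values are made up.
const sse = makeCompletionSSE({
  format: "openai",
  title: "Proxy error",
  message: "Upstream request failed.",
  reqId: "abc123",
  model: "gpt-4",
});
// For OpenAI-style formats this yields a single chunk terminated by a
// blank line, e.g.:
// data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk",...}
//
// For "anthropic" the same content is wrapped with an explicit event name:
// event: completion
// data: {"completion":"...","stop_reason":"Proxy error",...}
```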
File diff suppressed because one or more lines are too long
@@ -0,0 +1,45 @@
import { MistralAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload.js";
import * as tokenizer from "./mistral-tokenizer-js";

export function init() {
  tokenizer.initializemistralTokenizer();
  return true;
}

export function getTokenCount(prompt: MistralAIChatMessage[] | string) {
  if (typeof prompt === "string") {
    return getTextTokenCount(prompt);
  }

  let chunks = [];
  for (const message of prompt) {
    switch (message.role) {
      case "system":
        chunks.push(message.content);
        break;
      case "assistant":
        chunks.push(message.content + "</s>");
        break;
      case "user":
        chunks.push("[INST] " + message.content + " [/INST]");
        break;
    }
  }
  return getTextTokenCount(chunks.join(" "));
}

function getTextTokenCount(prompt: string) {
  // Don't try tokenizing if the prompt is massive to prevent DoS.
  // 500k characters should be sufficient for all supported models.
  if (prompt.length > 500000) {
    return {
      tokenizer: "length fallback",
      token_count: 100000,
    };
  }

  return {
    tokenizer: "mistral-tokenizer-js",
    token_count: tokenizer.encode(prompt.normalize("NFKC"))!.length,
  };
}
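`getTokenCount` approximates Mistral's chat template by flattening messages before tokenizing. A quick sketch of what a short exchange becomes; the flattening follows the switch above, but the messages themselves are invented:

```
// Invented example messages; the flattening mirrors getTokenCount above.
const messages: MistralAIChatMessage[] = [
  { role: "system", content: "Be concise." },
  { role: "user", content: "Hi" },
  { role: "assistant", content: "Hello!" },
];
// chunks.join(" ") produces:
// "Be concise. [INST] Hi [/INST] Hello!</s>"
// which is then encoded with mistral-tokenizer-js.
```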
@@ -2,7 +2,11 @@ import { Tiktoken } from "tiktoken/lite";
import cl100k_base from "tiktoken/encoders/cl100k_base.json";
import { logger } from "../../logger";
import { libSharp } from "../file-storage";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import type {
  GoogleAIChatMessage,
  OpenAIChatMessage,
} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import { z } from "zod";

const log = logger.child({ module: "tokenizer", service: "openai" });
const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
@@ -29,11 +33,11 @@ export async function getTokenCount(
    return getTextTokenCount(prompt);
  }

  const gpt4 = model.startsWith("gpt-4");
  const oldFormatting = model.startsWith("turbo-0301");
  const vision = model.includes("vision");

  const tokensPerMessage = gpt4 ? 3 : 4;
  const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present
  const tokensPerMessage = oldFormatting ? 4 : 3;
  const tokensPerName = oldFormatting ? -1 : 1; // older formatting replaces role with name if name is present

  let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;

@@ -228,3 +232,24 @@ export function getOpenAIImageCost(params: {
    token_count: Math.ceil(tokens),
  };
}

export function estimateGoogleAITokenCount(prompt: string | GoogleAIChatMessage[]) {
  if (typeof prompt === "string") {
    return getTextTokenCount(prompt);
  }

  const tokensPerMessage = 3;

  let numTokens = 0;
  for (const message of prompt) {
    numTokens += tokensPerMessage;
    numTokens += encoder.encode(message.parts[0].text).length;
  }

  numTokens += 3;

  return {
    tokenizer: "tiktoken (google-ai estimate)",
    token_count: numTokens,
  };
}
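`estimateGoogleAITokenCount` follows the familiar OpenAI-style chat accounting: 3 tokens of overhead per message plus 3 tokens of reply priming. A hedged arithmetic sketch, with assumed encoded lengths standing in for a real tiktoken encoder:

```
// Same accounting as estimateGoogleAITokenCount above; the per-message
// encoded lengths (10 and 25) are assumed values for illustration.
const encodedLengths = [10, 25];
const TOKENS_PER_MESSAGE = 3;
const REPLY_PRIMING = 3;
const estimate =
  encodedLengths.reduce((sum, len) => sum + TOKENS_PER_MESSAGE + len, 0) +
  REPLY_PRIMING; // (3+10) + (3+25) + 3 = 44
```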
@@ -1,5 +1,9 @@
import { Request } from "express";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import type {
  GoogleAIChatMessage,
  MistralAIChatMessage,
  OpenAIChatMessage,
} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import { assertNever } from "../utils";
import {
  init as initClaude,
@@ -9,12 +13,18 @@ import {
  init as initOpenAi,
  getTokenCount as getOpenAITokenCount,
  getOpenAIImageCost,
  estimateGoogleAITokenCount,
} from "./openai";
import {
  init as initMistralAI,
  getTokenCount as getMistralAITokenCount,
} from "./mistral";
import { APIFormat } from "../key-management";

export async function init() {
  initClaude();
  initOpenAi();
  initMistralAI();
}

/** Tagged union via `service` field of the different types of requests that can
@@ -24,8 +34,10 @@ type TokenCountRequest = { req: Request } & (
  | {
      prompt: string;
      completion?: never;
      service: "openai-text" | "anthropic" | "google-palm";
      service: "openai-text" | "anthropic" | "google-ai";
    }
  | { prompt?: GoogleAIChatMessage[]; completion?: never; service: "google-ai" }
  | { prompt: MistralAIChatMessage[]; completion?: never; service: "mistral-ai" }
  | { prompt?: never; completion: string; service: APIFormat }
  | { prompt?: never; completion?: never; service: "openai-image" }
);
@@ -65,11 +77,16 @@ export async function countTokens({
        }),
        tokenization_duration_ms: getElapsedMs(time),
      };
    case "google-palm":
      // TODO: Can't find a tokenization library for PaLM. There is an API
    case "google-ai":
      // TODO: Can't find a tokenization library for Gemini. There is an API
      // endpoint for it but it adds significant latency to the request.
      return {
        ...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
        ...estimateGoogleAITokenCount(prompt ?? (completion || [])),
        tokenization_duration_ms: getElapsedMs(time),
      };
    case "mistral-ai":
      return {
        ...getMistralAITokenCount(prompt ?? completion),
        tokenization_duration_ms: getElapsedMs(time),
      };
    default:
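The tagged union above lets the compiler enforce that each service is called with the matching prompt shape. A hedged usage sketch; `req` is whatever Express request is in scope, the message is invented, and the result shape is inferred from the cases shown above:

```
// Invented example; shapes follow the TokenCountRequest union above.
const result = await countTokens({
  req,
  service: "mistral-ai",
  prompt: [{ role: "user", content: "Hello" }],
});
// -> { tokenizer: "mistral-tokenizer-js", token_count: ..., tokenization_duration_ms: ... }
```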
@@ -9,7 +9,7 @@ export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object({
  "gpt4-turbo": z.number().optional().default(0),
  "dall-e": z.number().optional().default(0),
  claude: z.number().optional().default(0),
  bison: z.number().optional().default(0),
  "gemini-pro": z.number().optional().default(0),
  "aws-claude": z.number().optional().default(0),
});
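Because every family is `.optional().default(0)`, parsing a stored user record that predates the gemini-pro migration simply backfills zeros rather than failing validation. A minimal sketch, using only keys visible in the hunk above:

```
// Sketch: parsing an older record that lacks the newly added family.
const counts = tokenCountsSchema.parse({ claude: 1234 });
// counts["gemini-pro"] === 0, counts["aws-claude"] === 0, etc.
```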
@@ -14,7 +14,8 @@ import { config, getFirebaseApp } from "../../config";
import {
  getAzureOpenAIModelFamily,
  getClaudeModelFamily,
  getGooglePalmModelFamily,
  getGoogleAIModelFamily,
  getMistralAIModelFamily,
  getOpenAIModelFamily,
  MODEL_FAMILIES,
  ModelFamily,
@@ -33,7 +34,10 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
  "gpt4-turbo": 0,
  "dall-e": 0,
  claude: 0,
  bison: 0,
  "gemini-pro": 0,
  "mistral-tiny": 0,
  "mistral-small": 0,
  "mistral-medium": 0,
  "aws-claude": 0,
  "azure-turbo": 0,
  "azure-gpt4": 0,
@@ -397,8 +401,10 @@ function getModelFamilyForQuotaUsage(
      return getOpenAIModelFamily(model);
    case "anthropic":
      return getClaudeModelFamily(model);
    case "google-palm":
      return getGooglePalmModelFamily(model);
    case "google-ai":
      return getGoogleAIModelFamily(model);
    case "mistral-ai":
      return getMistralAIModelFamily(model);
    default:
      assertNever(api);
  }
+1
-1
@@ -15,5 +15,5 @@
  },
  "include": ["src"],
  "exclude": ["node_modules"],
  "files": ["src/types/custom.d.ts"]
  "files": ["src/shared/custom.d.ts"]
}