diff --git a/src/config.ts b/src/config.ts
index 939e6c5..860aa75 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -10,8 +10,10 @@ export type DequeueMode = "fair" | "random" | "none";
 type Config = {
   /** The port the proxy server will listen on. */
   port: number;
-  /** OpenAI API key, either a single key or a comma-delimeted list of keys. */
+  /** Comma-delimited list of OpenAI API keys. */
   openaiKey?: string;
+  /** Comma-delimited list of Anthropic API keys. */
+  anthropicKey?: string;
   /**
    * The proxy key to require for requests. Only applicable if the user
    * management mode is set to 'proxy_key', and required if so.
@@ -118,6 +120,7 @@ type Config = {
 export const config: Config = {
   port: getEnvWithDefault("PORT", 7860),
   openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
+  anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
   proxyKey: getEnvWithDefault("PROXY_KEY", ""),
   adminKey: getEnvWithDefault("ADMIN_KEY", ""),
   gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
@@ -221,6 +224,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
   "port",
   "logLevel",
   "openaiKey",
+  "anthropicKey",
   "proxyKey",
   "adminKey",
   "checkKeys",
@@ -265,7 +269,7 @@ function getEnvWithDefault<T>(name: string, defaultValue: T): T {
     return defaultValue;
   }
   try {
-    if (name === "OPENAI_KEY") {
+    if (name === "OPENAI_KEY" || name === "ANTHROPIC_KEY") {
      return value as unknown as T;
    }
    return JSON.parse(value) as T;
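As a quick illustration (not part of the diff, and assuming `process.env` values as shown), key variables bypass `JSON.parse` so comma-delimited lists come back verbatim while everything else is parsed:

```ts
// Hypothetical environment: ANTHROPIC_KEY="sk-ant-one,sk-ant-two"  PORT="7860"
const anthropicKey = getEnvWithDefault("ANTHROPIC_KEY", ""); // "sk-ant-one,sk-ant-two"
const port = getEnvWithDefault("PORT", 7860); // JSON.parse("7860") -> 7860 (a number)
```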
diff --git a/src/info-page.ts b/src/info-page.ts
index 481b300..50374cf 100644
--- a/src/info-page.ts
+++ b/src/info-page.ts
@@ -27,40 +27,52 @@ export const handleInfoPage = (req: Request, res: Response) => {
 function cacheInfoPageHtml(host: string) {
   const keys = keyPool.list();
   let keyInfo: Record<string, any> = { all: keys.length };
+
+  const openAIKeys = keys.filter((k) => k.service === "openai");
+  const anthropicKeys = keys.filter((k) => k.service === "anthropic");
+
+  let anthropicInfo: Record<string, any> = {
+    all: anthropicKeys.length,
+    active: anthropicKeys.filter((k) => !k.isDisabled).length,
+  };
+  let openAIInfo: Record<string, any> = {
+    all: openAIKeys.length,
+    active: openAIKeys.filter((k) => !k.isDisabled).length,
+  };
 
   if (keyPool.anyUnchecked()) {
     const uncheckedKeys = keys.filter((k) => !k.lastChecked);
-    keyInfo = {
-      ...keyInfo,
+    openAIInfo = {
+      ...openAIInfo,
       active: keys.filter((k) => !k.isDisabled).length,
       status: `Still checking ${uncheckedKeys.length} keys...`,
     };
   } else if (config.checkKeys) {
-    const trialKeys = keys.filter((k) => k.isTrial);
-    const turboKeys = keys.filter((k) => !k.isGpt4 && !k.isDisabled);
-    const gpt4Keys = keys.filter((k) => k.isGpt4 && !k.isDisabled);
+    const trialKeys = openAIKeys.filter((k) => k.isTrial);
+    const turboKeys = openAIKeys.filter((k) => !k.isGpt4 && !k.isDisabled);
+    const gpt4Keys = openAIKeys.filter((k) => k.isGpt4 && !k.isDisabled);
     const quota: Record<string, string> = { turbo: "", gpt4: "" };
-    const hasGpt4 = keys.some((k) => k.isGpt4);
+    const hasGpt4 = openAIKeys.some((k) => k.isGpt4);
+    const turboQuota = keyPool.remainingQuota("openai") * 100;
+    const gpt4Quota = keyPool.remainingQuota("openai", { gpt4: true }) * 100;
     if (config.quotaDisplayMode === "full") {
-      quota.turbo = `${keyPool.usageInUsd()} (${Math.round(
-        keyPool.remainingQuota() * 100
-      )}% remaining)`;
-      quota.gpt4 = `${keyPool.usageInUsd(true)} (${Math.round(
-        keyPool.remainingQuota(true) * 100
-      )}% remaining)`;
+      const turboUsage = keyPool.usageInUsd("openai");
+      const gpt4Usage = keyPool.usageInUsd("openai", { gpt4: true });
+      quota.turbo = `${turboUsage} (${Math.round(turboQuota)}% remaining)`;
+      quota.gpt4 = `${gpt4Usage} (${Math.round(gpt4Quota)}% remaining)`;
     } else {
-      quota.turbo = `${Math.round(keyPool.remainingQuota() * 100)}%`;
-      quota.gpt4 = `${Math.round(keyPool.remainingQuota(true) * 100)}%`;
+      quota.turbo = `${Math.round(turboQuota)}%`;
+      quota.gpt4 = `${Math.round(gpt4Quota)}%`;
     }
 
     if (!hasGpt4) {
       delete quota.gpt4;
     }
 
-    keyInfo = {
-      ...keyInfo,
+    openAIInfo = {
+      ...openAIInfo,
       trial: trialKeys.length,
       active: {
         turbo: turboKeys.length,
@@ -70,6 +82,11 @@ function cacheInfoPageHtml(host: string) {
     };
   }
 
+  keyInfo = {
+    ...(openAIKeys.length ? { openai: openAIInfo } : {}),
+    ...(anthropicKeys.length ? { anthropic: anthropicInfo } : {}),
+  };
+
   const info = {
     uptime: process.uptime(),
     endpoints: {
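For reference, a sketch of the per-service shape `keyInfo` now takes on the info page; the counts here are invented and the exact fields depend on `config.checkKeys`:

```ts
// Illustrative only: keys are grouped per service instead of one flat pool.
const keyInfo = {
  openai: { all: 4, trial: 1, active: { turbo: 2, gpt4: 1 } },
  anthropic: { all: 2, active: 2 }, // Anthropic has no key checker, so only counts
};
```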
diff --git a/src/key-management/anthropic/provider.ts b/src/key-management/anthropic/provider.ts
new file mode 100644
index 0000000..e878178
--- /dev/null
+++ b/src/key-management/anthropic/provider.ts
@@ -0,0 +1,188 @@
+import crypto from "crypto";
+import { Key, KeyProvider } from "..";
+import { config } from "../../config";
+import { logger } from "../../logger";
+
+export const ANTHROPIC_SUPPORTED_MODELS = [
+  "claude-instant-v1",
+  "claude-instant-v1-100k",
+  "claude-v1",
+  "claude-v1-100k",
+] as const;
+export type AnthropicModel = (typeof ANTHROPIC_SUPPORTED_MODELS)[number];
+
+export interface AnthropicKey extends Key {
+  readonly service: "anthropic";
+  /** The time at which this key was last rate limited. */
+  rateLimitedAt: number;
+  /** The time until which this key is rate limited. */
+  rateLimitedUntil: number;
+}
+
+/**
+ * We don't get rate limit headers from Anthropic so if we get a 429, we just
+ * lock out the key for 10 seconds.
+ */
+const RATE_LIMIT_LOCKOUT = 10000;
+
+export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
+  readonly service = "anthropic";
+
+  private keys: AnthropicKey[] = [];
+  private log = logger.child({ module: "key-provider", service: this.service });
+
+  constructor() {
+    const keyConfig = config.anthropicKey?.trim();
+    if (!keyConfig) {
+      this.log.warn(
+        "ANTHROPIC_KEY is not set. Anthropic API will not be available."
+      );
+      return;
+    }
+    const bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
+    for (const key of bareKeys) {
+      const newKey: AnthropicKey = {
+        key,
+        service: this.service,
+        isGpt4: false,
+        isTrial: false,
+        isDisabled: false,
+        promptCount: 0,
+        lastUsed: 0,
+        rateLimitedAt: 0,
+        rateLimitedUntil: 0,
+        hash: `ant-${crypto
+          .createHash("sha256")
+          .update(key)
+          .digest("hex")
+          .slice(0, 8)}`,
+        lastChecked: 0,
+      };
+      this.keys.push(newKey);
+    }
+    this.log.info({ keyCount: this.keys.length }, "Loaded Anthropic keys.");
+  }
+
+  public init() {
+    // Nothing to do as Anthropic's API doesn't provide any usage information
+    // so there is no key checker implementation and no need to start it.
+  }
+
+  public list() {
+    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
+  }
+
+  public get(_model: AnthropicModel) {
+    // Currently, all Anthropic keys have access to all models. This will
+    // almost certainly change when they move out of beta later this year.
+    const availableKeys = this.keys.filter((k) => !k.isDisabled);
+    if (availableKeys.length === 0) {
+      throw new Error("No Anthropic keys available.");
+    }
+
+    // (largely copied from the OpenAI provider, without trial key support)
+    // Select a key, from highest priority to lowest priority:
+    // 1. Keys which are not rate limited
+    //    a. If all keys were rate limited recently, select the least-recently
+    //       rate limited key.
+    // 2. Keys which have not been used in the longest time
+
+    const now = Date.now();
+
+    const keysByPriority = availableKeys.sort((a, b) => {
+      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
+      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
+
+      if (aRateLimited && !bRateLimited) return 1;
+      if (!aRateLimited && bRateLimited) return -1;
+      if (aRateLimited && bRateLimited) {
+        return a.rateLimitedAt - b.rateLimitedAt;
+      }
+      return a.lastUsed - b.lastUsed;
+    });
+
+    const selectedKey = keysByPriority[0];
+    selectedKey.lastUsed = now;
+    selectedKey.rateLimitedAt = now;
+    // Intended to throttle the queue processor as otherwise it will just
+    // flood the API with requests and we want to wait a sec to see if we're
+    // going to get a rate limit error on this key.
+    selectedKey.rateLimitedUntil = now + 1000;
+    return { ...selectedKey };
+  }
+
+  public disable(key: AnthropicKey) {
+    const keyFromPool = this.keys.find((k) => k.key === key.key);
+    if (!keyFromPool || keyFromPool.isDisabled) return;
+    keyFromPool.isDisabled = true;
+    this.log.warn({ key: key.hash }, "Key disabled");
+  }
+
+  public available() {
+    return this.keys.filter((k) => !k.isDisabled).length;
+  }
+
+  // No key checker for Anthropic
+  public anyUnchecked() {
+    return false;
+  }
+
+  public incrementPrompt(hash?: string) {
+    const key = this.keys.find((k) => k.hash === hash);
+    if (!key) return;
+    key.promptCount++;
+  }
+
+  public getLockoutPeriod(_model: AnthropicModel) {
+    const activeKeys = this.keys.filter((k) => !k.isDisabled);
+    // Don't lock out if there are no keys available or the queue will stall.
+    // Just let it through so the add-key middleware can throw an error.
+    if (activeKeys.length === 0) return 0;
+
+    const now = Date.now();
+    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
+    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
+
+    if (anyNotRateLimited) return 0;
+
+    // If all keys are rate-limited, return the time until the first key is
+    // ready.
+    const timeUntilFirstReady = Math.min(
+      ...activeKeys.map((k) => k.rateLimitedUntil - now)
+    );
+    return timeUntilFirstReady;
+  }
+
+  /**
+   * This is called when we receive a 429, which means there are already five
+   * concurrent requests running on this key. We don't have any information on
+   * when these requests will resolve so all we can do is wait a bit and try
+   * again.
+   * We will lock the key for 10 seconds, which should let a few of the other
+   * generations finish. This is an arbitrary number but the goal is to balance
+   * between not hammering the API with requests and not locking out a key that
+   * is actually available.
+   * TODO: Try to assign requests to slots on each key so we have an idea of
+   * how long each slot has been running and can make a more informed decision
+   * on how long to lock the key.
+   */
+  public markRateLimited(keyHash: string) {
+    this.log.warn({ key: keyHash }, "Key rate limited");
+    const key = this.keys.find((k) => k.hash === keyHash)!;
+    const now = Date.now();
+    key.rateLimitedAt = now;
+    key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
+  }
+
+  public remainingQuota() {
+    const activeKeys = this.keys.filter((k) => !k.isDisabled).length;
+    const allKeys = this.keys.length;
+    if (activeKeys === 0) return 0;
+    return Math.round((activeKeys / allKeys) * 100) / 100;
+  }
+
+  public usageInUsd() {
+    return "$0.00 / ∞";
+  }
+}
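A minimal sketch, not part of the diff, of how a queue consumer might drive this provider API; the import path and call site are assumed for illustration:

```ts
import { AnthropicKeyProvider } from "./src/key-management/anthropic/provider";

const provider = new AnthropicKeyProvider();
provider.init(); // no-op for Anthropic, but part of the KeyProvider contract

function tryDispatch() {
  // If every key is cooling down, wait until the soonest one frees up.
  const waitMs = provider.getLockoutPeriod("claude-v1");
  if (waitMs > 0) {
    setTimeout(tryDispatch, waitMs);
    return;
  }
  const key = provider.get("claude-v1"); // throws if no keys are enabled
  console.log(`dispatching with key ${key.hash}`);
  // On a 429 from Anthropic, the response handler would report back with:
  // provider.markRateLimited(key.hash);
}

tryDispatch();
```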
diff --git a/src/key-management/index.ts b/src/key-management/index.ts
index 2214bb1..b2916f3 100644
--- a/src/key-management/index.ts
+++ b/src/key-management/index.ts
@@ -1,5 +1,65 @@
+import { OPENAI_SUPPORTED_MODELS, OpenAIModel } from "./openai/provider";
+import {
+  ANTHROPIC_SUPPORTED_MODELS,
+  AnthropicModel,
+} from "./anthropic/provider";
 import { KeyPool } from "./key-pool";
 
-export type { Key, Model } from "./key-pool";
+export type AIService = "openai" | "anthropic";
+export type Model = OpenAIModel | AnthropicModel;
+
+export interface Key {
+  /** The API key itself. Never log this, use `hash` instead. */
+  readonly key: string;
+  /** The service that this key is for. */
+  service: AIService;
+  /** Whether this is a free trial key. These are prioritized over paid keys if they can fulfill the request. */
+  isTrial: boolean;
+  /** Whether this key has been provisioned for GPT-4. */
+  isGpt4: boolean;
+  /** Whether this key is currently disabled, meaning its quota has been exceeded or it has been revoked. */
+  isDisabled: boolean;
+  /** The number of prompts that have been sent with this key. */
+  promptCount: number;
+  /** The time at which this key was last used. */
+  lastUsed: number;
+  /** The time at which this key was last checked. */
+  lastChecked: number;
+  /** Hash of the key, for logging and to find the key in the pool. */
+  hash: string;
+}
+
+/*
+KeyPool and KeyProvider's similarities are a relic of the old design where
+there was only a single KeyPool for OpenAI keys. Now that there are multiple
+supported services, the service-specific functionality has been moved to
+KeyProvider and KeyPool is just a wrapper around multiple KeyProviders,
+delegating to the appropriate one based on the model requested.
+
+Existing code will continue to call methods on KeyPool, which routes them to
+the appropriate KeyProvider or returns data aggregated across all KeyProviders
+for service-agnostic functionality.
+*/
+
+export interface KeyProvider<T extends Key = Key> {
+  readonly service: AIService;
+  init(): void;
+  get(model: Model): T;
+  list(): Omit<T, "key">[];
+  disable(key: T): void;
+  available(): number;
+  anyUnchecked(): boolean;
+  incrementPrompt(hash: string): void;
+  getLockoutPeriod(model: Model): number;
+  remainingQuota(options?: Record<string, unknown>): number;
+  usageInUsd(options?: Record<string, unknown>): string;
+  markRateLimited(hash: string): void;
+}
+
 export const keyPool = new KeyPool();
-export { SUPPORTED_MODELS } from "./key-pool";
+export const SUPPORTED_MODELS = [
+  ...OPENAI_SUPPORTED_MODELS,
+  ...ANTHROPIC_SUPPORTED_MODELS,
+] as const;
+export type SupportedModel = (typeof SUPPORTED_MODELS)[number];
+export { OPENAI_SUPPORTED_MODELS, ANTHROPIC_SUPPORTED_MODELS };
diff --git a/src/key-management/key-pool.ts b/src/key-management/key-pool.ts
index 323c486..5730e53 100644
--- a/src/key-management/key-pool.ts
+++ b/src/key-management/key-pool.ts
@@ -1,378 +1,102 @@
-/* Manages OpenAI API keys. Tracks usage, disables expired keys, and provides
-round-robin access to keys. 
Keys are stored in the OPENAI_KEY environment -variable as a comma-separated list of keys. */ -import crypto from "crypto"; -import fs from "fs"; -import http from "http"; -import path from "path"; -import { config } from "../config"; -import { logger } from "../logger"; -import { KeyChecker } from "./key-checker"; - -// TODO: Made too many assumptions about OpenAI being the only provider and now -// this doesn't really work for Anthropic. Create a Provider interface and -// implement Pool, Checker, and Models for each provider. - -export type Model = OpenAIModel | AnthropicModel; -export type OpenAIModel = "gpt-3.5-turbo" | "gpt-4"; -export type AnthropicModel = "claude-v1" | "claude-instant-v1"; -export const SUPPORTED_MODELS: readonly Model[] = [ - "gpt-3.5-turbo", - "gpt-4", - "claude-v1", - "claude-instant-v1", -] as const; - -export type Key = { - /** The OpenAI API key itself. */ - key: string; - /** Whether this is a free trial key. These are prioritized over paid keys if they can fulfill the request. */ - isTrial: boolean; - /** Whether this key has been provisioned for GPT-4. */ - isGpt4: boolean; - /** Whether this key is currently disabled. We set this if we get a 429 or 401 response from OpenAI. */ - isDisabled: boolean; - /** Threshold at which a warning email will be sent by OpenAI. */ - softLimit: number; - /** Threshold at which the key will be disabled because it has reached the user-defined limit. */ - hardLimit: number; - /** The maximum quota allocated to this key by OpenAI. */ - systemHardLimit: number; - /** The current usage of this key. */ - usage: number; - /** The number of prompts that have been sent with this key. */ - promptCount: number; - /** The time at which this key was last used. */ - lastUsed: number; - /** The time at which this key was last checked. */ - lastChecked: number; - /** Key hash for displaying usage in the dashboard. */ - hash: string; - /** The time at which this key was last rate limited. */ - rateLimitedAt: number; - /** - * Last known X-RateLimit-Requests-Reset header from OpenAI, converted to a - * number. - * Formatted as a `\d+(m|s)` string denoting the time until the limit resets. - * Specifically, it seems to indicate the time until the key's quota will be - * fully restored; the key may be usable before this time as the limit is a - * rolling window. - * - * Requests which return a 429 do not count against the quota. - * - * Requests which fail for other reasons (e.g. 401) count against the quota. - */ - rateLimitRequestsReset: number; - /** - * Last known X-RateLimit-Tokens-Reset header from OpenAI, converted to a - * number. - * Appears to follow the same format as `rateLimitRequestsReset`. - * - * Requests which fail do not count against the quota as they do not consume - * tokens. 
- */
-  rateLimitTokensReset: number;
-};
-
-export type KeyUpdate = Omit<
-  Partial<Key>,
-  "key" | "hash" | "lastUsed" | "lastChecked" | "promptCount"
->;
+import type * as http from "http";
+import { AnthropicKeyProvider } from "./anthropic/provider";
+import { Key, AIService, Model, KeyProvider } from "./index";
+import { OpenAIKeyProvider } from "./openai/provider";
 
 export class KeyPool {
-  private keys: Key[] = [];
-  private checker?: KeyChecker;
-  private log = logger.child({ module: "key-pool" });
+  private keyProviders: KeyProvider[] = [];
 
   constructor() {
-    const keyString = config.openaiKey;
-    if (!keyString?.trim()) {
-      throw new Error("OPENAI_KEY environment variable is not set");
-    }
-    let bareKeys: string[];
-    bareKeys = keyString.split(",").map((k) => k.trim());
-    bareKeys = [...new Set(bareKeys)];
-    for (const k of bareKeys) {
-      const newKey = {
-        key: k,
-        isGpt4: false,
-        isTrial: false,
-        isDisabled: false,
-        softLimit: 0,
-        hardLimit: 0,
-        systemHardLimit: 0,
-        usage: 0,
-        lastUsed: 0,
-        lastChecked: 0,
-        promptCount: 0,
-        hash: crypto.createHash("sha256").update(k).digest("hex").slice(0, 8),
-        rateLimitedAt: 0,
-        rateLimitRequestsReset: 0,
-        rateLimitTokensReset: 0,
-      };
-      this.keys.push(newKey);
-
-    }
-    this.log.info({ keyCount: this.keys.length }, "Loaded keys");
+    this.keyProviders.push(new OpenAIKeyProvider());
+    this.keyProviders.push(new AnthropicKeyProvider());
   }
 
   public init() {
-    if (config.checkKeys) {
-      this.checker = new KeyChecker(this.keys, this.update.bind(this));
-      this.checker.start();
-    }
-  }
-
-  /**
-   * Returns a list of all keys, with the key field removed.
-   * Don't mutate returned keys, use a KeyPool method instead.
-   **/
-  public list() {
-    return this.keys.map((key) => {
-      return Object.freeze({
-        ...key,
-        key: undefined,
-      });
-    });
-  }
-
-  public get(model: Model) {
-    const needGpt4 = model.startsWith("gpt-4");
-    const availableKeys = this.keys.filter(
-      (key) => !key.isDisabled && (!needGpt4 || key.isGpt4)
-    );
-    if (availableKeys.length === 0) {
-      let message = "No keys available. Please add more keys.";
-      if (needGpt4) {
-        message =
-          "No GPT-4 keys available. Please add more keys or select a non-GPT-4 model.";
-      }
-      throw new Error(message);
-    }
-
-    // Select a key, from highest priority to lowest priority:
-    // 1. Keys which are not rate limited
-    //    a. We can assume any rate limits over a minute ago are expired
-    //    b. If all keys were rate limited in the last minute, select the
-    //       least recently rate limited key
-    // 2. Keys which are trials
-    // 3. Keys which have not been used in the longest time
-
-    const now = Date.now();
-    const rateLimitThreshold = 60 * 1000;
-
-    const keysByPriority = availableKeys.sort((a, b) => {
-      const aRateLimited = now - a.rateLimitedAt < rateLimitThreshold;
-      const bRateLimited = now - b.rateLimitedAt < rateLimitThreshold;
-
-      if (aRateLimited && !bRateLimited) return 1;
-      if (!aRateLimited && bRateLimited) return -1;
-      if (aRateLimited && bRateLimited) {
-        return a.rateLimitedAt - b.rateLimitedAt;
-      }
-
-      if (a.isTrial && !b.isTrial) return -1;
-      if (!a.isTrial && b.isTrial) return 1;
-
-      return a.lastUsed - b.lastUsed;
-    });
-
-    const selectedKey = keysByPriority[0];
-    selectedKey.lastUsed = Date.now();
-
-    // When a key is selected, we rate-limit it for a brief period of time to
-    // prevent the queue processor from immediately flooding it with requests
-    // while the initial request is still being processed (which is when we will
-    // get new rate limit headers). 
- // Instead, we will let a request through every second until the key - // becomes fully saturated and locked out again. - selectedKey.rateLimitedAt = Date.now(); - selectedKey.rateLimitRequestsReset = 1000; - return { ...selectedKey }; - } - - /** Called by the key checker to update key information. */ - public update(keyHash: string, update: KeyUpdate) { - const keyFromPool = this.keys.find((k) => k.hash === keyHash)!; - Object.assign(keyFromPool, { ...update, lastChecked: Date.now() }); - // this.writeKeyStatus(); - } - - public disable(key: Key) { - const keyFromPool = this.keys.find((k) => k.key === key.key)!; - if (keyFromPool.isDisabled) return; - keyFromPool.isDisabled = true; - // If it's disabled just set the usage to the hard limit so it doesn't - // mess with the aggregate usage. - keyFromPool.usage = keyFromPool.hardLimit; - this.log.warn({ key: key.hash }, "Key disabled"); - } - - public available() { - return this.keys.filter((k) => !k.isDisabled).length; - } - - public anyUnchecked() { - return config.checkKeys && this.keys.some((key) => !key.lastChecked); - } - - /** - * Given a model, returns the period until a key will be available to service - * the request, or returns 0 if a key is ready immediately. - */ - public getLockoutPeriod(model: Model = "gpt-4"): number { - const needGpt4 = model.startsWith("gpt-4"); - const activeKeys = this.keys.filter( - (key) => !key.isDisabled && (!needGpt4 || key.isGpt4) - ); - - if (activeKeys.length === 0) { - // If there are no active keys for this model we can't fulfill requests. - // We'll return 0 to let the request through and return an error, - // otherwise the request will be stuck in the queue forever. - return 0; - } - - // A key is rate-limited if its `rateLimitedAt` plus the greater of its - // `rateLimitRequestsReset` and `rateLimitTokensReset` is after the - // current time. - - // If there are any keys that are not rate-limited, we can fulfill requests. - const now = Date.now(); - const rateLimitedKeys = activeKeys.filter((key) => { - const resetTime = Math.max( - key.rateLimitRequestsReset, - key.rateLimitTokensReset + this.keyProviders.forEach((provider) => provider.init()); + const availableKeys = this.available(); + if (availableKeys === 0) { + throw new Error( + "No keys loaded. Ensure either OPENAI_KEY or ANTHROPIC_KEY is set." ); - return now < key.rateLimitedAt + resetTime; - }).length; - const anyNotRateLimited = rateLimitedKeys < activeKeys.length; - - if (anyNotRateLimited) { - return 0; - } - - // If all keys are rate-limited, return the time until the first key is - // ready. 
-    const timeUntilFirstReady = Math.min(
-      ...activeKeys.map((key) => {
-        const resetTime = Math.max(
-          key.rateLimitRequestsReset,
-          key.rateLimitTokensReset
-        );
-        return key.rateLimitedAt + resetTime - now;
-      })
-    );
-    return timeUntilFirstReady;
-  }
-
-  public markRateLimited(keyHash: string) {
-    this.log.warn({ key: keyHash }, "Key rate limited");
-    const key = this.keys.find((k) => k.hash === keyHash)!;
-    key.rateLimitedAt = Date.now();
-  }
-
-  public incrementPrompt(keyHash?: string) {
-    if (!keyHash) return;
-    const key = this.keys.find((k) => k.hash === keyHash)!;
-    key.promptCount++;
-  }
-
-  public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) {
-    const key = this.keys.find((k) => k.hash === keyHash)!;
-    const requestsReset = headers["x-ratelimit-reset-requests"];
-    const tokensReset = headers["x-ratelimit-reset-tokens"];
-
-    // Sometimes OpenAI only sends one of the two rate limit headers, it's
-    // unclear why.
-
-    if (requestsReset && typeof requestsReset === "string") {
-      this.log.info(
-        { key: key.hash, requestsReset },
-        `Updating rate limit requests reset time`
-      );
-      key.rateLimitRequestsReset = getResetDurationMillis(requestsReset);
-    }
-
-    if (tokensReset && typeof tokensReset === "string") {
-      this.log.info(
-        { key: key.hash, tokensReset },
-        `Updating rate limit tokens reset time`
-      );
-      key.rateLimitTokensReset = getResetDurationMillis(tokensReset);
-    }
-
-    if (!requestsReset && !tokensReset) {
-      this.log.warn(
-        { key: key.hash },
-        `No rate limit headers in OpenAI response; skipping update`
-      );
-      return;
     }
   }
 
-  /** Returns the remaining aggregate quota for all keys as a percentage. */
-  public remainingQuota(gpt4 = false) {
-    const keys = this.keys.filter((k) => k.isGpt4 === gpt4);
-    if (keys.length === 0) return 0;
-
-    const totalUsage = keys.reduce((acc, key) => {
-      // Keys can slightly exceed their quota
-      return acc + Math.min(key.usage, key.hardLimit);
-    }, 0);
-    const totalLimit = keys.reduce((acc, { hardLimit }) => acc + hardLimit, 0);
-
-    return 1 - totalUsage / totalLimit;
+  public get(model: Model): Key {
+    const service = this.getService(model);
+    return this.getKeyProvider(service).get(model);
   }
 
-  /** Returns used and available usage in USD. */
-  public usageInUsd(gpt4 = false) {
-    const keys = this.keys.filter((k) => k.isGpt4 === gpt4);
-    if (keys.length === 0) return "???";
+  public list(): Omit<Key, "key">[] {
+    return this.keyProviders.flatMap((provider) => provider.list());
+  }
 
-    const totalHardLimit = keys.reduce(
-      (acc, { hardLimit }) => acc + hardLimit,
+  public disable(key: Key): void {
+    const service = this.getKeyProvider(key.service);
+    service.disable(key);
+  }
+
+  // TODO: this probably needs to be scoped to a specific provider. I think the
+  // only code calling this is the error handler which needs to know how many
+  // more keys are available for the provider the user tried to use.
+  public available(): number {
+    return this.keyProviders.reduce(
+      (sum, provider) => sum + provider.available(),
       0
     );
-    const totalUsage = keys.reduce((acc, key) => {
-      // Keys can slightly exceed their quota
-      return acc + Math.min(key.usage, key.hardLimit);
-    }, 0);
-
-    return `$${totalUsage.toFixed(2)} / $${totalHardLimit.toFixed(2)}`;
   }
 
-  /** Writes key status to disk. */
-  // public writeKeyStatus() {
-  //   const keys = this.keys.map((key) => ({
-  //     key: key.key,
-  //     isGpt4: key.isGpt4,
-  //     usage: key.usage,
-  //     hardLimit: key.hardLimit,
-  //     isDisabled: key.isDisabled,
-  //   }));
-  //   fs.writeFileSync(
-  //     path.join(__dirname, "..", "keys.json"),
-  //     JSON.stringify(keys, null, 2)
-  //   );
-  // }
-}
-
-
-/**
- * Converts reset string ("21.0032s" or "21ms") to a number of milliseconds.
- * Result is clamped to 10s even though the API returns up to 60s, because the
- * API returns the time until the entire quota is reset, even if a key may be
- * able to fulfill requests before then due to partial resets.
- **/
-function getResetDurationMillis(resetDuration?: string): number {
-  const match = resetDuration?.match(/(\d+(\.\d+)?)(s|ms)/);
-  if (match) {
-    const [, time, , unit] = match;
-    const value = parseFloat(time);
-    const result = unit === "s" ? value * 1000 : value;
-    return Math.min(result, 10000);
+  public anyUnchecked(): boolean {
+    return this.keyProviders.some((provider) => provider.anyUnchecked());
+  }
+
+  public incrementPrompt(key: Key): void {
+    const provider = this.getKeyProvider(key.service);
+    provider.incrementPrompt(key.hash);
+  }
+
+  public getLockoutPeriod(model: Model): number {
+    const service = this.getService(model);
+    return this.getKeyProvider(service).getLockoutPeriod(model);
+  }
+
+  public markRateLimited(key: Key): void {
+    const provider = this.getKeyProvider(key.service);
+    provider.markRateLimited(key.hash);
+  }
+
+  public updateRateLimits(key: Key, headers: http.IncomingHttpHeaders): void {
+    const provider = this.getKeyProvider(key.service);
+    if (provider instanceof OpenAIKeyProvider) {
+      provider.updateRateLimits(key.hash, headers);
+    }
+  }
+
+  public remainingQuota(
+    service: AIService,
+    options?: Record<string, unknown>
+  ): number {
+    return this.getKeyProvider(service).remainingQuota(options);
+  }
+
+  public usageInUsd(
+    service: AIService,
+    options?: Record<string, unknown>
+  ): string {
+    return this.getKeyProvider(service).usageInUsd(options);
+  }
+
+  private getService(model: Model): AIService {
+    if (model.startsWith("gpt")) {
+      // https://platform.openai.com/docs/models/model-endpoint-compatibility
+      return "openai";
+    } else if (model.startsWith("claude-")) {
+      // https://console.anthropic.com/docs/api/reference#parameters
+      return "anthropic";
+    }
+    throw new Error(`Unknown service for model '${model}'`);
+  }
+
+  private getKeyProvider(service: AIService): KeyProvider {
+    return this.keyProviders.find((provider) => provider.service === service)!;
   }
-  return 0;
 }
diff --git a/src/key-management/key-checker.ts b/src/key-management/openai/checker.ts
similarity index 91%
rename from src/key-management/key-checker.ts
rename to src/key-management/openai/checker.ts
index 6b52505..c96c4e5 100644
--- a/src/key-management/key-checker.ts
+++ b/src/key-management/openai/checker.ts
@@ -1,7 +1,7 @@
 import axios, { AxiosError } from "axios";
 import { Configuration, OpenAIApi } from "openai";
-import { logger } from "../logger";
-import type { Key, KeyPool } from "./key-pool";
+import { logger } from "../../logger";
+import type { OpenAIKey, OpenAIKeyProvider } from "./provider";
 
 const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
 const KEY_CHECK_PERIOD = 5 * 60 * 1000; // 5 minutes
@@ -26,16 +26,16 @@ type OpenAIError = {
   error: { type: string; code: string; param: unknown; message: string };
 };
 
-type UpdateFn = typeof KeyPool.prototype.update;
+type UpdateFn = typeof OpenAIKeyProvider.prototype.update;
 
-export class KeyChecker {
-  private 
readonly keys: Key[]; - private log = logger.child({ module: "key-checker" }); +export class OpenAIKeyChecker { + private readonly keys: OpenAIKey[]; + private log = logger.child({ module: "key-checker", service: "openai" }); private timeout?: NodeJS.Timeout; private updateKey: UpdateFn; private lastCheck = 0; - constructor(keys: Key[], updateKey: UpdateFn) { + constructor(keys: OpenAIKey[], updateKey: UpdateFn) { this.keys = keys; this.updateKey = updateKey; } @@ -110,7 +110,7 @@ export class KeyChecker { this.timeout = setTimeout(() => this.checkKey(oldestKey), delay); } - private async checkKey(key: Key) { + private async checkKey(key: OpenAIKey) { // It's possible this key might have been disabled while we were waiting // for the next check. if (key.isDisabled) { @@ -180,7 +180,7 @@ export class KeyChecker { } private async getProvisionedModels( - key: Key + key: OpenAIKey ): Promise<{ turbo: boolean; gpt4: boolean }> { const openai = new OpenAIApi(new Configuration({ apiKey: key.key })); const models = (await openai.listModels()!).data.data; @@ -189,7 +189,7 @@ export class KeyChecker { return { turbo, gpt4 }; } - private async getSubscription(key: Key) { + private async getSubscription(key: OpenAIKey) { const { data } = await axios.get( GET_SUBSCRIPTION_URL, { headers: { Authorization: `Bearer ${key.key}` } } @@ -197,8 +197,8 @@ export class KeyChecker { return data; } - private async getUsage(key: Key) { - const querystring = KeyChecker.getUsageQuerystring(key.isTrial); + private async getUsage(key: OpenAIKey) { + const querystring = OpenAIKeyChecker.getUsageQuerystring(key.isTrial); const url = `${GET_USAGE_URL}?${querystring}`; const { data } = await axios.get(url, { headers: { Authorization: `Bearer ${key.key}` }, @@ -206,8 +206,8 @@ export class KeyChecker { return parseFloat((data.total_usage / 100).toFixed(2)); } - private handleAxiosError(key: Key, error: AxiosError) { - if (error.response && KeyChecker.errorIsOpenAiError(error)) { + private handleAxiosError(key: OpenAIKey, error: AxiosError) { + if (error.response && OpenAIKeyChecker.errorIsOpenAiError(error)) { const { status, data } = error.response; if (status === 401) { this.log.warn( @@ -239,7 +239,7 @@ export class KeyChecker { * Trial key usage reporting is inaccurate, so we need to run an actual * completion to test them for liveness. */ - private async assertCanGenerate(key: Key): Promise { + private async assertCanGenerate(key: OpenAIKey): Promise { const openai = new OpenAIApi(new Configuration({ apiKey: key.key })); // This will throw an AxiosError if the key is invalid or out of quota. await openai.createChatCompletion({ diff --git a/src/key-management/openai/provider.ts b/src/key-management/openai/provider.ts new file mode 100644 index 0000000..7161ca3 --- /dev/null +++ b/src/key-management/openai/provider.ts @@ -0,0 +1,360 @@ +/* Manages OpenAI API keys. Tracks usage, disables expired keys, and provides +round-robin access to keys. Keys are stored in the OPENAI_KEY environment +variable as a comma-separated list of keys. 
*/
+import crypto from "crypto";
+import fs from "fs";
+import http from "http";
+import path from "path";
+import { KeyProvider, Key, Model } from "../index";
+import { config } from "../../config";
+import { logger } from "../../logger";
+import { OpenAIKeyChecker } from "./checker";
+
+export type OpenAIModel = "gpt-3.5-turbo" | "gpt-4";
+export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
+  "gpt-3.5-turbo",
+  "gpt-4",
+] as const;
+
+export interface OpenAIKey extends Key {
+  readonly service: "openai";
+  /** The current usage of this key. */
+  usage: number;
+  /** Threshold at which a warning email will be sent by OpenAI. */
+  softLimit: number;
+  /** Threshold at which the key will be disabled because it has reached the user-defined limit. */
+  hardLimit: number;
+  /** The maximum quota allocated to this key by OpenAI. */
+  systemHardLimit: number;
+  /** The time at which this key was last rate limited. */
+  rateLimitedAt: number;
+  /**
+   * Last known X-RateLimit-Requests-Reset header from OpenAI, converted to a
+   * number.
+   * Formatted as a `\d+(m|s)` string denoting the time until the limit resets.
+   * Specifically, it seems to indicate the time until the key's quota will be
+   * fully restored; the key may be usable before this time as the limit is a
+   * rolling window.
+   *
+   * Requests which return a 429 do not count against the quota.
+   *
+   * Requests which fail for other reasons (e.g. 401) count against the quota.
+   */
+  rateLimitRequestsReset: number;
+  /**
+   * Last known X-RateLimit-Tokens-Reset header from OpenAI, converted to a
+   * number.
+   * Appears to follow the same format as `rateLimitRequestsReset`.
+   *
+   * Requests which fail do not count against the quota as they do not consume
+   * tokens.
+   */
+  rateLimitTokensReset: number;
+}
+
+export type OpenAIKeyUpdate = Omit<
+  Partial<OpenAIKey>,
+  "key" | "hash" | "lastUsed" | "lastChecked" | "promptCount"
+>;
+
+export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
+  readonly service = "openai" as const;
+
+  private keys: OpenAIKey[] = [];
+  private checker?: OpenAIKeyChecker;
+  private log = logger.child({ module: "key-provider", service: this.service });
+
+  constructor() {
+    const keyString = config.openaiKey?.trim();
+    if (!keyString) {
+      this.log.warn("OPENAI_KEY is not set. OpenAI API will not be available.");
+      return;
+    }
+    const bareKeys = [...new Set(keyString.split(",").map((k) => k.trim()))];
+    for (const k of bareKeys) {
+      const newKey = {
+        key: k,
+        service: "openai" as const,
+        isGpt4: false,
+        isTrial: false,
+        isDisabled: false,
+        softLimit: 0,
+        hardLimit: 0,
+        systemHardLimit: 0,
+        usage: 0,
+        lastUsed: 0,
+        lastChecked: 0,
+        promptCount: 0,
+        hash: `oai-${crypto
+          .createHash("sha256")
+          .update(k)
+          .digest("hex")
+          .slice(0, 8)}`,
+        rateLimitedAt: 0,
+        rateLimitRequestsReset: 0,
+        rateLimitTokensReset: 0,
+      };
+      this.keys.push(newKey);
+    }
+    this.log.info({ keyCount: this.keys.length }, "Loaded OpenAI keys.");
+  }
+
+  public init() {
+    if (config.checkKeys) {
+      this.checker = new OpenAIKeyChecker(this.keys, this.update.bind(this));
+      this.checker.start();
+    }
+  }
+
+  /**
+   * Returns a list of all keys, with the key field removed.
+   * Don't mutate returned keys, use a KeyPool method instead.
+ **/ + public list() { + return this.keys.map((key) => { + return Object.freeze({ + ...key, + key: undefined, + }); + }); + } + + public get(model: Model) { + const needGpt4 = model.startsWith("gpt-4"); + const availableKeys = this.keys.filter( + (key) => !key.isDisabled && (!needGpt4 || key.isGpt4) + ); + if (availableKeys.length === 0) { + let message = needGpt4 + ? "No active OpenAI keys available." + : "No GPT-4 keys available. Try selecting a non-GPT-4 model."; + throw new Error(message); + } + + // Select a key, from highest priority to lowest priority: + // 1. Keys which are not rate limited + // a. We ignore rate limits from over a minute ago + // b. If all keys were rate limited in the last minute, select the + // least recently rate limited key + // 2. Keys which are trials + // 3. Keys which have not been used in the longest time + + const now = Date.now(); + const rateLimitThreshold = 60 * 1000; + + const keysByPriority = availableKeys.sort((a, b) => { + const aRateLimited = now - a.rateLimitedAt < rateLimitThreshold; + const bRateLimited = now - b.rateLimitedAt < rateLimitThreshold; + + if (aRateLimited && !bRateLimited) return 1; + if (!aRateLimited && bRateLimited) return -1; + if (aRateLimited && bRateLimited) { + return a.rateLimitedAt - b.rateLimitedAt; + } + + if (a.isTrial && !b.isTrial) return -1; + if (!a.isTrial && b.isTrial) return 1; + + return a.lastUsed - b.lastUsed; + }); + + const selectedKey = keysByPriority[0]; + selectedKey.lastUsed = now; + + // When a key is selected, we rate-limit it for a brief period of time to + // prevent the queue processor from immediately flooding it with requests + // while the initial request is still being processed (which is when we will + // get new rate limit headers). + // Instead, we will let a request through every second until the key + // becomes fully saturated and locked out again. + selectedKey.rateLimitedAt = now; + selectedKey.rateLimitRequestsReset = 1000; + return { ...selectedKey }; + } + + /** Called by the key checker to update key information. */ + public update(keyHash: string, update: OpenAIKeyUpdate) { + const keyFromPool = this.keys.find((k) => k.hash === keyHash)!; + Object.assign(keyFromPool, { ...update, lastChecked: Date.now() }); + // this.writeKeyStatus(); + } + + /** Disables a key, or does nothing if the key isn't in this pool. */ + public disable(key: Key) { + const keyFromPool = this.keys.find((k) => k.key === key.key); + if (!keyFromPool || keyFromPool.isDisabled) return; + keyFromPool.isDisabled = true; + // If it's disabled just set the usage to the hard limit so it doesn't + // mess with the aggregate usage. + keyFromPool.usage = keyFromPool.hardLimit; + this.log.warn({ key: key.hash }, "Key disabled"); + } + + public available() { + return this.keys.filter((k) => !k.isDisabled).length; + } + + public anyUnchecked() { + return !!config.checkKeys && this.keys.some((key) => !key.lastChecked); + } + + /** + * Given a model, returns the period until a key will be available to service + * the request, or returns 0 if a key is ready immediately. + */ + public getLockoutPeriod(model: Model = "gpt-4"): number { + const needGpt4 = model.startsWith("gpt-4"); + const activeKeys = this.keys.filter( + (key) => !key.isDisabled && (!needGpt4 || key.isGpt4) + ); + + if (activeKeys.length === 0) { + // If there are no active keys for this model we can't fulfill requests. + // We'll return 0 to let the request through and return an error, + // otherwise the request will be stuck in the queue forever. 
+ return 0; + } + + // A key is rate-limited if its `rateLimitedAt` plus the greater of its + // `rateLimitRequestsReset` and `rateLimitTokensReset` is after the + // current time. + + // If there are any keys that are not rate-limited, we can fulfill requests. + const now = Date.now(); + const rateLimitedKeys = activeKeys.filter((key) => { + const resetTime = Math.max( + key.rateLimitRequestsReset, + key.rateLimitTokensReset + ); + return now < key.rateLimitedAt + resetTime; + }).length; + const anyNotRateLimited = rateLimitedKeys < activeKeys.length; + + if (anyNotRateLimited) { + return 0; + } + + // If all keys are rate-limited, return the time until the first key is + // ready. + const timeUntilFirstReady = Math.min( + ...activeKeys.map((key) => { + const resetTime = Math.max( + key.rateLimitRequestsReset, + key.rateLimitTokensReset + ); + return key.rateLimitedAt + resetTime - now; + }) + ); + return timeUntilFirstReady; + } + + public markRateLimited(keyHash: string) { + this.log.warn({ key: keyHash }, "Key rate limited"); + const key = this.keys.find((k) => k.hash === keyHash)!; + key.rateLimitedAt = Date.now(); + } + + public incrementPrompt(keyHash?: string) { + const key = this.keys.find((k) => k.hash === keyHash); + if (!key) return; + key.promptCount++; + } + + public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) { + const key = this.keys.find((k) => k.hash === keyHash)!; + const requestsReset = headers["x-ratelimit-reset-requests"]; + const tokensReset = headers["x-ratelimit-reset-tokens"]; + + // Sometimes OpenAI only sends one of the two rate limit headers, it's + // unclear why. + + if (requestsReset && typeof requestsReset === "string") { + this.log.info( + { key: key.hash, requestsReset }, + `Updating rate limit requests reset time` + ); + key.rateLimitRequestsReset = getResetDurationMillis(requestsReset); + } + + if (tokensReset && typeof tokensReset === "string") { + this.log.info( + { key: key.hash, tokensReset }, + `Updating rate limit tokens reset time` + ); + key.rateLimitTokensReset = getResetDurationMillis(tokensReset); + } + + if (!requestsReset && !tokensReset) { + this.log.warn( + { key: key.hash }, + `No rate limit headers in OpenAI response; skipping update` + ); + return; + } + } + + /** Returns the remaining aggregate quota for all keys as a percentage. */ + public remainingQuota({ gpt4 }: { gpt4: boolean } = { gpt4: false }): number { + const keys = this.keys.filter((k) => k.isGpt4 === gpt4); + if (keys.length === 0) return 0; + + const totalUsage = keys.reduce((acc, key) => { + // Keys can slightly exceed their quota + return acc + Math.min(key.usage, key.hardLimit); + }, 0); + const totalLimit = keys.reduce((acc, { hardLimit }) => acc + hardLimit, 0); + + return 1 - totalUsage / totalLimit; + } + + /** Returns used and available usage in USD. */ + public usageInUsd({ gpt4 }: { gpt4: boolean } = { gpt4: false }): string { + const keys = this.keys.filter((k) => k.isGpt4 === gpt4); + if (keys.length === 0) return "???"; + + const totalHardLimit = keys.reduce( + (acc, { hardLimit }) => acc + hardLimit, + 0 + ); + const totalUsage = keys.reduce((acc, key) => { + // Keys can slightly exceed their quota + return acc + Math.min(key.usage, key.hardLimit); + }, 0); + + return `$${totalUsage.toFixed(2)} / $${totalHardLimit.toFixed(2)}`; + } + + /** Writes key status to disk. 
*/ + // public writeKeyStatus() { + // const keys = this.keys.map((key) => ({ + // key: key.key, + // isGpt4: key.isGpt4, + // usage: key.usage, + // hardLimit: key.hardLimit, + // isDisabled: key.isDisabled, + // })); + // fs.writeFileSync( + // path.join(__dirname, "..", "keys.json"), + // JSON.stringify(keys, null, 2) + // ); + // } +} + +/** + * Converts reset string ("21.0032s" or "21ms") to a number of milliseconds. + * Result is clamped to 10s even though the API returns up to 60s, because the + * API returns the time until the entire quota is reset, even if a key may be + * able to fulfill requests before then due to partial resets. + **/ +function getResetDurationMillis(resetDuration?: string): number { + const match = resetDuration?.match(/(\d+(\.\d+)?)(s|ms)/); + if (match) { + const [, time, , unit] = match; + const value = parseFloat(time); + const result = unit === "s" ? value * 1000 : value; + return Math.min(result, 10000); + } + return 0; +} diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts new file mode 100644 index 0000000..585764f --- /dev/null +++ b/src/proxy/anthropic.ts @@ -0,0 +1,171 @@ +import { Request, Router } from "express"; +import * as http from "http"; +import { createProxyMiddleware } from "http-proxy-middleware"; +import { config } from "../config"; +import { logger } from "../logger"; +import { + addKey, + finalizeBody, + languageFilter, + limitOutputTokens, + transformOutboundPayload, +} from "./middleware/request"; +import { + ProxyResHandlerWithBody, + createOnProxyResHandler, + handleInternalError, +} from "./middleware/response"; +import { createQueueMiddleware } from "./queue"; + +const rewriteAnthropicRequest = ( + proxyReq: http.ClientRequest, + req: Request, + res: http.ServerResponse +) => { + req.api = "anthropic"; + const rewriterPipeline = [ + addKey, + languageFilter, + limitOutputTokens, + transformOutboundPayload, + finalizeBody, + ]; + + try { + for (const rewriter of rewriterPipeline) { + rewriter(proxyReq, req, res, {}); + } + } catch (error) { + req.log.error(error, "Error while executing proxy rewriter"); + proxyReq.destroy(error as Error); + } +}; + +/** Only used for non-streaming requests. */ +const anthropicResponseHandler: ProxyResHandlerWithBody = async ( + _proxyRes, + req, + res, + body +) => { + if (typeof body !== "object") { + throw new Error("Expected body to be an object"); + } + + if (config.promptLogging) { + const host = req.get("host"); + body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`; + } + + if (!req.originalUrl.includes("/v1/complete")) { + req.log.info("Transforming Anthropic response to OpenAI format"); + body = transformAnthropicResponse(body); + } + res.status(200).json(body); +}; + +/** + * Transforms a model response from the Anthropic API to match those from the + * OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This + * is only used for non-streaming requests as streaming requests are handled + * on-the-fly. 
+ */ +function transformAnthropicResponse( + anthropicBody: Record +): Record { + return { + id: "ant-" + anthropicBody.log_id, + object: "chat.completion", + created: Date.now(), + model: anthropicBody.model, + usage: { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + choices: [ + { + message: { + role: "assistant", + content: anthropicBody.completion?.trim(), + }, + finish_reason: anthropicBody.stop_reason, + index: 0, + }, + ], + }; +} + +const anthropicProxy = createProxyMiddleware({ + target: "https://api.anthropic.com", + changeOrigin: true, + on: { + proxyReq: rewriteAnthropicRequest, + proxyRes: createOnProxyResHandler([anthropicResponseHandler]), + error: handleInternalError, + }, + selfHandleResponse: true, + logger, + pathRewrite: { + // If the user sends a request to /v1/chat/completions (the OpenAI endpoint) + // we will transform the payload and rewrite the path to /v1/complete. + "^/v1/chat/completions": "/v1/complete", + }, +}); +const queuedAnthropicProxy = createQueueMiddleware(anthropicProxy); + +const anthropicRouter = Router(); +anthropicRouter.use((req, _res, next) => { + if (!req.path.startsWith("/v1/")) { + req.url = `/v1${req.url}`; + } + next(); +}); +anthropicRouter.get("/v1/models", (req, res) => { + res.json(buildFakeModelsResponse()); +}); +anthropicRouter.post("/v1/complete", queuedAnthropicProxy); +// This is the OpenAI endpoint, to let users send OpenAI-formatted requests +// to the Anthropic API. We need to rewrite them first. +anthropicRouter.post("/v1/chat/completions", queuedAnthropicProxy); +// Redirect browser requests to the homepage. +anthropicRouter.get("*", (req, res, next) => { + const isBrowser = req.headers["user-agent"]?.includes("Mozilla"); + if (isBrowser) { + res.redirect("/"); + } else { + next(); + } +}); + +function buildFakeModelsResponse() { + const claudeVariants = [ + "claude-v1", + "claude-v1-100k", + "claude-instant-v1", + "claude-instant-v1-100k", + "claude-v1.3", + "claude-v1.3-100k", + "claude-v1.2", + "claude-v1.0", + "claude-instant-v1.1", + "claude-instant-v1.1-100k", + "claude-instant-v1.0", + ]; + + const models = claudeVariants.map((id) => ({ + id, + object: "model", + created: new Date().getTime(), + owned_by: "anthropic", + permission: [], + root: "claude", + parent: null, + })); + + return { + models, + }; +} + +export const anthropic = anthropicRouter; diff --git a/src/proxy/kobold.ts b/src/proxy/kobold.ts index e278e2f..2337558 100644 --- a/src/proxy/kobold.ts +++ b/src/proxy/kobold.ts @@ -9,7 +9,6 @@ import { logger } from "../logger"; import { ipLimiter } from "./rate-limit"; import { addKey, - checkStreaming, finalizeBody, languageFilter, limitOutputTokens, @@ -41,11 +40,11 @@ const rewriteRequest = ( } req.api = "kobold"; + req.body.stream = false; const rewriterPipeline = [ addKey, transformKoboldPayload, languageFilter, - checkStreaming, limitOutputTokens, finalizeBody, ]; diff --git a/src/proxy/middleware/request/add-key.ts b/src/proxy/middleware/request/add-key.ts index 88079e4..98edfe9 100644 --- a/src/proxy/middleware/request/add-key.ts +++ b/src/proxy/middleware/request/add-key.ts @@ -1,45 +1,52 @@ -import { Key, Model, keyPool, SUPPORTED_MODELS } from "../../../key-management"; +import { Key, keyPool } from "../../../key-management"; import type { ExpressHttpProxyReqCallback } from "."; -/** Add an OpenAI key from the pool to the request. */ +/** Add a key that can service this request to the request object. 
*/ export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => { let assignedKey: Key; - // Not all clients request a particular model. - // If they request a model, just use that. - // If they don't request a model, use a GPT-4 key if there is an active one, - // otherwise use a GPT-3.5 key. - - // TODO: Anthropic mode should prioritize Claude over Claude Instant. - // Each provider needs to define some priority order for their models. - - if (bodyHasModel(req.body)) { - assignedKey = keyPool.get(req.body.model); - } else { - try { - assignedKey = keyPool.get("gpt-4"); - } catch { - assignedKey = keyPool.get("gpt-3.5-turbo"); - } + if (!req.body?.model) { + throw new Error("You must specify a model with your request."); } + + // This should happen somewhere else but addKey is guaranteed to run first. + req.isStreaming = req.body.stream === true || req.body.stream === "true"; + req.body.stream = req.isStreaming; + + // Anthropic support has a special endpoint that accepts OpenAI-formatted + // requests and translates them into Anthropic requests. On this endpoint, + // the requested model is an OpenAI one even though we're actually sending + // an Anthropic request. + // For such cases, ignore the requested model entirely. + // Real Anthropic requests come in via /proxy/anthropic/v1/complete + // The OpenAI-compatible endpoint is /proxy/anthropic/v1/chat/completions + + const openaiCompatible = + req.originalUrl === "/proxy/anthropic/v1/chat/completions"; + if (openaiCompatible) { + req.log.debug("Using an Anthropic key for an OpenAI-compatible request"); + req.api = "openai"; + // We don't assign the model here, that will happen when transforming the + // request body. + assignedKey = keyPool.get("claude-v1"); + } else { + assignedKey = keyPool.get(req.body.model); + } + req.key = assignedKey; req.log.info( { key: assignedKey.hash, model: req.body?.model, - isGpt4: assignedKey.isGpt4, + fromApi: req.api, + toApi: assignedKey.service, }, "Assigned key to request" ); - // TODO: Requests to Anthropic models use `X-API-Key`. - proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`); + if (assignedKey.service === "anthropic") { + proxyReq.setHeader("X-API-Key", assignedKey.key); + } else { + proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`); + } }; - -function bodyHasModel(body: any): body is { model: Model } { - // Model names can have suffixes indicating the frozen release version but - // OpenAI and Anthropic will use the latest version if you omit the suffix. - const isSupportedModel = (model: string) => - SUPPORTED_MODELS.some((supported) => model.startsWith(supported)); - return typeof body?.model === "string" && isSupportedModel(body.model); -} diff --git a/src/proxy/middleware/request/check-streaming.ts b/src/proxy/middleware/request/check-streaming.ts deleted file mode 100644 index 9d30bb9..0000000 --- a/src/proxy/middleware/request/check-streaming.ts +++ /dev/null @@ -1,24 +0,0 @@ -import { ExpressHttpProxyReqCallback, isCompletionRequest } from "."; - -/** - * If a stream is requested, mark the request as such so the response middleware - * knows to use the alternate EventSource response handler. - * Kobold requests can't currently be streamed as they use a different event - * format than the OpenAI API and we need to rewrite the events as they come in, - * which I have not yet implemented. 
- */ -export const checkStreaming: ExpressHttpProxyReqCallback = (_proxyReq, req) => { - const streamableApi = req.api !== "kobold"; - if (isCompletionRequest(req) && req.body?.stream) { - if (!streamableApi) { - req.log.warn( - { api: req.api, key: req.key?.hash }, - `Streaming requested, but ${req.api} streaming is not supported.` - ); - req.body.stream = false; - return; - } - req.body.stream = true; - req.isStreaming = true; - } -}; diff --git a/src/proxy/middleware/request/index.ts b/src/proxy/middleware/request/index.ts index a8a4c00..19328fb 100644 --- a/src/proxy/middleware/request/index.ts +++ b/src/proxy/middleware/request/index.ts @@ -3,20 +3,23 @@ import type { ClientRequest } from "http"; import type { ProxyReqCallback } from "http-proxy"; export { addKey } from "./add-key"; -export { checkStreaming } from "./check-streaming"; export { finalizeBody } from "./finalize-body"; export { languageFilter } from "./language-filter"; export { limitCompletions } from "./limit-completions"; export { limitOutputTokens } from "./limit-output-tokens"; export { transformKoboldPayload } from "./transform-kobold-payload"; +export { transformOutboundPayload } from "./transform-outbound-payload"; const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions"; +const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete"; -/** Returns true if we're making a chat completion request. */ +/** Returns true if we're making a request to a completion endpoint. */ export function isCompletionRequest(req: Request) { return ( req.method === "POST" && - req.path.startsWith(OPENAI_CHAT_COMPLETION_ENDPOINT) + [OPENAI_CHAT_COMPLETION_ENDPOINT, ANTHROPIC_COMPLETION_ENDPOINT].some( + (endpoint) => req.path.startsWith(endpoint) + ) ); } diff --git a/src/proxy/middleware/request/limit-completions.ts b/src/proxy/middleware/request/limit-completions.ts index 0261b06..6f9da52 100644 --- a/src/proxy/middleware/request/limit-completions.ts +++ b/src/proxy/middleware/request/limit-completions.ts @@ -1,6 +1,9 @@ import { ExpressHttpProxyReqCallback, isCompletionRequest } from "."; -/** Don't allow multiple completions to be requested to prevent abuse. */ +/** + * Don't allow multiple completions to be requested to prevent abuse. + * OpenAI-only, Anthropic provides no such parameter. + **/ export const limitCompletions: ExpressHttpProxyReqCallback = ( _proxyReq, req diff --git a/src/proxy/middleware/request/limit-output-tokens.ts b/src/proxy/middleware/request/limit-output-tokens.ts index 91f91e1..2e1c9e1 100644 --- a/src/proxy/middleware/request/limit-output-tokens.ts +++ b/src/proxy/middleware/request/limit-output-tokens.ts @@ -1,29 +1,43 @@ +import { Request } from "express"; import { config } from "../../../config"; -import { logger } from "../../../logger"; import { ExpressHttpProxyReqCallback, isCompletionRequest } from "."; const MAX_TOKENS = config.maxOutputTokens; -/** Enforce a maximum number of tokens requested from OpenAI. */ +/** Enforce a maximum number of tokens requested from the model. */ export const limitOutputTokens: ExpressHttpProxyReqCallback = ( _proxyReq, req ) => { if (isCompletionRequest(req) && req.body?.max_tokens) { - // convert bad or missing input to a MAX_TOKENS - if (typeof req.body.max_tokens !== "number") { - logger.warn( - `Invalid max_tokens value: ${req.body.max_tokens}. 
Using ${MAX_TOKENS}` + const requestedMaxTokens = getMaxTokensFromRequest(req); + let maxTokens = requestedMaxTokens; + + if (typeof requestedMaxTokens !== "number") { + req.log.warn( + { requestedMaxTokens, clampedMaxTokens: MAX_TOKENS }, + "Invalid max tokens value. Using default value." ); - req.body.max_tokens = MAX_TOKENS; + maxTokens = MAX_TOKENS; } - const originalTokens = req.body.max_tokens; - req.body.max_tokens = Math.min(req.body.max_tokens, MAX_TOKENS); - if (originalTokens !== req.body.max_tokens) { - logger.warn( - `Limiting max_tokens from ${originalTokens} to ${req.body.max_tokens}` + // TODO: this is not going to scale well, need to implement a better way + // of translating request parameters from one API to another. + maxTokens = Math.min(maxTokens, MAX_TOKENS); + if (req.key!.service === "openai") { + req.body.max_tokens = maxTokens; + } else if (req.key!.service === "anthropic") { + req.body.max_tokens_to_sample = maxTokens; + } + + if (requestedMaxTokens !== maxTokens) { + req.log.warn( + `Limiting max tokens from ${requestedMaxTokens} to ${maxTokens}` ); } } }; + +function getMaxTokensFromRequest(req: Request) { + return (req.body?.max_tokens || req.body?.max_tokens_to_sample) ?? MAX_TOKENS; +} diff --git a/src/proxy/middleware/request/transform-kobold-payload.ts b/src/proxy/middleware/request/transform-kobold-payload.ts index 9e37f9b..e0c14f7 100644 --- a/src/proxy/middleware/request/transform-kobold-payload.ts +++ b/src/proxy/middleware/request/transform-kobold-payload.ts @@ -1,3 +1,8 @@ +/** + * Transforms a KoboldAI payload into an OpenAI payload. + * @deprecated Kobold input format isn't supported anymore as all popular + * frontends support reverse proxies or changing their base URL. + */ import { logger } from "../../../logger"; import type { ExpressHttpProxyReqCallback } from "."; @@ -63,6 +68,10 @@ export const transformKoboldPayload: ExpressHttpProxyReqCallback = ( _proxyReq, req ) => { + if (req.api !== "kobold") { + throw new Error("transformKoboldPayload called for non-kobold request."); + } + const { body } = req; const { prompt, max_length, rep_pen, top_p, temperature } = body; diff --git a/src/proxy/middleware/request/transform-outbound-payload.ts b/src/proxy/middleware/request/transform-outbound-payload.ts new file mode 100644 index 0000000..df3f28c --- /dev/null +++ b/src/proxy/middleware/request/transform-outbound-payload.ts @@ -0,0 +1,125 @@ +import { Request } from "express"; +import { z } from "zod"; +import type { ExpressHttpProxyReqCallback } from "."; + +// https://console.anthropic.com/docs/api/reference#-v1-complete +const AnthropicV1CompleteSchema = z.object({ + model: z.string().regex(/^claude-/), + prompt: z.string(), + max_tokens_to_sample: z.number(), + stop_sequences: z.array(z.string()).optional(), + stream: z.boolean().optional().default(false), + temperature: z.number().optional().default(1), + top_k: z.number().optional().default(-1), + top_p: z.number().optional().default(-1), + metadata: z.any().optional(), +}); + +// https://platform.openai.com/docs/api-reference/chat/create +const OpenAIV1ChatCompletionSchema = z.object({ + model: z.string().regex(/^gpt/), + messages: z.array( + z.object({ + role: z.enum(["system", "user", "assistant"]), + content: z.string(), + name: z.string().optional(), + }) + ), + temperature: z.number().optional().default(1), + top_p: z.number().optional().default(1), + n: z.literal(1).optional(), + stream: z.boolean().optional().default(false), + stop: z.union([z.string(), 
z.array(z.string())]).optional(), + max_tokens: z.number().optional(), + frequency_penalty: z.number().optional().default(0), + presence_penalty: z.number().optional().default(0), + logit_bias: z.any().optional(), + user: z.string().optional(), +}); + +/** Transforms an incoming request body to one that matches the target API. */ +export const transformOutboundPayload: ExpressHttpProxyReqCallback = ( + _proxyReq, + req +) => { + if (req.retryCount > 0) { + // We've already transformed the payload once, so don't do it again. + return; + } + + const inboundService = req.api; + const outboundService = req.key!.service; + + if (inboundService === outboundService) { + return; + } + + // Not supported yet, and unnecessary in practice since every popular + // frontend can send OpenAI-compatible payloads. + if (inboundService === "anthropic" && outboundService === "openai") { + throw new Error( + "Anthropic -> OpenAI request transformation not supported. Provide an OpenAI-compatible payload, or use the /claude endpoint." + ); + } + + if (inboundService === "openai" && outboundService === "anthropic") { + req.body = openaiToAnthropic(req.body, req); + return; + } + + throw new Error( + `Unsupported transformation: ${inboundService} -> ${outboundService}` + ); +}; + +function openaiToAnthropic(body: any, req: Request) { + const result = OpenAIV1ChatCompletionSchema.safeParse(body); + if (!result.success) { + // don't log the prompt + const { messages, ...params } = body; + req.log.error( + { issues: result.error.issues, params }, + "Invalid OpenAI-to-Anthropic request" + ); + throw result.error; + } + + const { messages, ...rest } = result.data; + const prompt = + result.data.messages + .map((m) => { + let role: string = m.role; + if (role === "assistant") { + role = "Assistant"; + } else if (role === "system") { + role = "System"; + } else if (role === "user") { + role = "Human"; + } + // https://console.anthropic.com/docs/prompt-design + // `name` isn't supported by Anthropic but we can still try to use it. + return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${ + m.content + }`; + }) + .join("") + "\n\nAssistant: "; + + // When translating from OpenAI to Anthropic, we obviously can't use the + // provided OpenAI model name as-is. We will instead select a Claude model, + // choosing either the 100k-token model or the 9k-token model depending on + // the length of the prompt. I'm not bringing in the full OpenAI tokenizer + // for this, so we'll use Anthropic's guideline of ~28000 characters to about + // 8k tokens (https://console.anthropic.com/docs/prompt-design#prompt-length) + // as the cutoff, minus a little bit for safety. + + // For smaller prompts we use claude-v1.1, whose responses are less heavily + // filtered. For big prompts, claude-v1-100k (which auto-selects the latest + // model) is all we can use. + const model = prompt.length > 25000 ? "claude-v1-100k" : "claude-v1.1"; + + return { + ...rest, + model, + prompt, + max_tokens_to_sample: rest.max_tokens, + stop_sequences: rest.stop, + }; +}
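To make the transform concrete, here is a rough input/output sketch for openaiToAnthropic. The payload values are invented for illustration; the real output also carries the schema defaults (and the original max_tokens/stop keys) because ...rest is spread as-is.

// Hypothetical inbound body (OpenAI chat format):
const exampleInput = {
  model: "gpt-4",
  max_tokens: 300,
  stop: ["\n\nHuman:"],
  messages: [
    { role: "system", content: "You are a pirate." },
    { role: "user", name: "Bob", content: "Ahoy!" },
  ],
};
// Approximate result of openaiToAnthropic(exampleInput, req):
// {
//   model: "claude-v1.1",
//   prompt: "\n\nSystem: You are a pirate.\n\nHuman: (as Bob) Ahoy!\n\nAssistant: ",
//   max_tokens_to_sample: 300,
//   stop_sequences: ["\n\nHuman:"],
//   temperature: 1, top_p: 1, stream: false, ... // zod defaults via ...rest
// }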
"claude-v1-100k" : "claude-v1.1"; + + return { + ...rest, + model, + prompt, + max_tokens_to_sample: rest.max_tokens, + stop_sequences: rest.stop, + }; +} diff --git a/src/proxy/middleware/response/handle-streamed-response.ts b/src/proxy/middleware/response/handle-streamed-response.ts index a9431ca..a218356 100644 --- a/src/proxy/middleware/response/handle-streamed-response.ts +++ b/src/proxy/middleware/response/handle-streamed-response.ts @@ -1,6 +1,29 @@ -import { Response } from "express"; +import { Request, Response } from "express"; import * as http from "http"; import { RawResponseBodyHandler, decodeResponseBody } from "."; +import { buildFakeSseMessage } from "../../queue"; + +type OpenAiChatCompletionResponse = { + id: string; + object: string; + created: number; + model: string; + choices: { + message: { role: string; content: string }; + finish_reason: string | null; + index: number; + }[]; +}; + +type AnthropicCompletionResponse = { + completion: string; + stop_reason: string; + truncated: boolean; + stop: any; + model: string; + log_id: string; + exception: null; +}; /** * Consume the SSE stream and forward events to the client. Once the stream is @@ -11,18 +34,28 @@ import { RawResponseBodyHandler, decodeResponseBody } from "."; * in the event a streamed request results in a non-200 response, we need to * fall back to the non-streaming response handler so that the error handler * can inspect the error response. + * + * Currently most frontends don't support Anthropic streaming, so users can opt + * to send requests for Claude models via an endpoint that accepts OpenAI- + * compatible requests and translates the received Anthropic SSE events into + * OpenAI ones, essentially pretending to be an OpenAI streaming API. */ export const handleStreamedResponse: RawResponseBodyHandler = async ( proxyRes, req, res ) => { + // If these differ, the user is using the OpenAI-compatibile endpoint, so + // we need to translate the SSE events into OpenAI completion events for their + // frontend. + const fromApi = req.api; + const toApi = req.key!.service; if (!req.isStreaming) { req.log.error( { api: req.api, key: req.key?.hash }, - `handleEventSource called for non-streaming request, which isn't valid.` + `handleStreamedResponse called for non-streaming request, which isn't valid.` ); - throw new Error("handleEventSource called for non-streaming request."); + throw new Error("handleStreamedResponse called for non-streaming request."); } if (proxyRes.statusCode !== 200) { @@ -53,42 +86,81 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( res.flushHeaders(); } - const chunks: Buffer[] = []; - proxyRes.on("data", (chunk) => { - chunks.push(chunk); - res.write(chunk); - }); + const fullChunks: string[] = []; + let chunkBuffer: string[] = []; + let messageBuffer = ""; + let lastPosition = 0; + + type ProxyResHandler = (...args: T[]) => void; + function withErrorHandling(fn: ProxyResHandler) { + return (...args: T[]) => { + try { + fn(...args); + } catch (error) { + proxyRes.emit("error", error); + } + }; + } + + proxyRes.on( + "data", + withErrorHandling((chunk) => { + // We may receive multiple (or partial) SSE messages in a single chunk, so + // we need to buffer and emit seperate stream events for full messages so + // we can parse/transform them properly. 
+ const str = chunk.toString(); + chunkBuffer.push(str); + + const newMessages = (messageBuffer + chunkBuffer.join("")).split( + /\r?\n\r?\n/ // Anthropic uses CRLF line endings (out-of-spec btw) + ); + chunkBuffer = []; + messageBuffer = newMessages.pop() || ""; + + for (const message of newMessages) { + proxyRes.emit("full-sse-event", message); + } + }) + ); + + proxyRes.on( + "full-sse-event", + withErrorHandling((data) => { + const { event, position } = transformEvent( + data, + fromApi, + toApi, + lastPosition + ); + fullChunks.push(event); + lastPosition = position; + res.write(event + "\n\n"); + }) + ); + + proxyRes.on( + "end", + withErrorHandling(() => { + let finalBody = convertEventsToFinalResponse(fullChunks, req); + req.log.info( + { api: req.api, key: req.key?.hash }, + `Finished proxying SSE stream.` + ); + res.end(); + resolve(finalBody); + }) + ); - proxyRes.on("end", () => { - const finalBody = convertEventsToOpenAiResponse(chunks); - req.log.info( - { api: req.api, key: req.key?.hash }, - `Finished proxying SSE stream.` - ); - res.end(); - resolve(finalBody); - }); proxyRes.on("error", (err) => { req.log.error( { error: err, api: req.api, key: req.key?.hash }, `Error while streaming response.` ); - // OAI's spec doesn't allow for error events and clients wouldn't know - // what to do with them anyway, so we'll just send a completion event - // with the error message. - const fakeErrorEvent = { - id: "chatcmpl-error", - object: "chat.completion.chunk", - created: Date.now(), - model: "", - choices: [ - { - delta: { content: "[Proxy streaming error: " + err.message + "]" }, - index: 0, - finish_reason: "error", - }, - ], - }; + const fakeErrorEvent = buildFakeSseMessage( + "mid-stream-error", + err.message, + req + ); res.write(`data: ${JSON.stringify(fakeErrorEvent)}\n\n`); res.write("data: [DONE]\n\n"); res.end(); @@ -97,8 +169,57 @@ export const handleStreamedResponse: RawResponseBodyHandler = async ( }); }; +/** + * Transforms SSE events from the given response API into events compatible with + * the API requested by the client. + */ +function transformEvent( + data: string, + requestApi: string, + responseApi: string, + lastPosition: number +) { + if (requestApi === responseApi) { + return { position: -1, event: data }; + } + + if (requestApi === "anthropic" && responseApi === "openai") { + throw new Error(`Anthropic -> OpenAI streaming not implemented.`); + } + + // Anthropic sends the full completion so far with each event whereas OpenAI + // only sends the delta. To make the SSE events compatible, we remove + // everything before `lastPosition` from the completion. + if (!data.startsWith("data:")) { + return { position: lastPosition, event: data }; + } + + if (data.startsWith("data: [DONE]")) { + return { position: lastPosition, event: data }; + } + + const event = JSON.parse(data.slice("data: ".length)); + const newEvent = { + id: "ant-" + event.log_id, + object: "chat.completion.chunk", + created: Date.now(), + model: event.model, + choices: [ + { + index: 0, + delta: { content: event.completion?.slice(lastPosition) }, + finish_reason: event.stop_reason, + }, + ], + }; + return { + position: event.completion.length, + event: `data: ${JSON.stringify(newEvent)}`, + }; +} + /** Copy headers, excluding ones we're already setting for the SSE response. 
*/ -const copyHeaders = (proxyRes: http.IncomingMessage, res: Response) => { +function copyHeaders(proxyRes: http.IncomingMessage, res: Response) { const toOmit = [ "content-length", "content-encoding", @@ -112,66 +233,63 @@ res.setHeader(key, value); } } -}; +} -type OpenAiChatCompletionResponse = { - id: string; - object: string; - created: number; - model: string; - choices: { - message: { role: string; content: string }; - finish_reason: string | null; - index: number; - }[]; -}; +function convertEventsToFinalResponse(events: string[], req: Request) { + if (req.key!.service === "openai") { + let response: OpenAiChatCompletionResponse = { + id: "", + object: "", + created: 0, + model: "", + choices: [], + }; + response = events.reduce((acc, event, i) => { + if (!event.startsWith("data: ")) { + return acc; + } -/** Converts the event stream chunks into a single completion response. */ -const convertEventsToOpenAiResponse = (chunks: Buffer[]) => { - let response: OpenAiChatCompletionResponse = { - id: "", - object: "", - created: 0, - model: "", - choices: [], - }; - const events = Buffer.concat(chunks) - .toString() - .trim() - .split("\n\n") - .map((line) => line.trim()); + if (event === "data: [DONE]") { + return acc; + } - response = events.reduce((acc, chunk, i) => { - if (!chunk.startsWith("data: ")) { + const data = JSON.parse(event.slice("data: ".length)); + if (i === 0) { + return { + id: data.id, + object: data.object, + created: data.created, + model: data.model, + choices: [ + { + message: { role: data.choices[0].delta.role, content: "" }, + index: 0, + finish_reason: null, + }, + ], + }; + } + + if (data.choices[0].delta.content) { + acc.choices[0].message.content += data.choices[0].delta.content; + } + acc.choices[0].finish_reason = data.choices[0].finish_reason; return acc; - } - - if (chunk === "data: [DONE]") { - return acc; - } - - const data = JSON.parse(chunk.slice("data: ".length)); - if (i === 0) { - return { - id: data.id, - object: data.object, - created: data.created, - model: data.model, - choices: [ - { - message: { role: data.choices[0].delta.role, content: "" }, - index: 0, - finish_reason: null, - }, - ], - }; - } - - if (data.choices[0].delta.content) { - acc.choices[0].message.content += data.choices[0].delta.content; - } - acc.choices[0].finish_reason = data.choices[0].finish_reason; - return acc; - }, response); - return response; -}; + }, response); + return response; + } + if (req.key!.service === "anthropic") { + /* + * Complete responses from Anthropic are conveniently just the same as the + * final SSE event before the "DONE" event, so we can reuse that. + */ + const lastEvent = events[events.length - 2].toString(); + const data = JSON.parse(lastEvent.slice("data: ".length)); + const response: AnthropicCompletionResponse = { + ...data, + log_id: req.id, + }; + return response; + } + throw new Error("Unknown service; cannot convert stream to response."); +}
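The cumulative-to-delta trick in transformEvent above is easiest to see with two consecutive events (all field values here are invented): Anthropic resends the entire completion so far, while OpenAI-style clients expect only the newly generated text.

// Invented sample of two consecutive Anthropic completion payloads:
const first = { completion: "Hello", stop_reason: null, model: "claude-v1", log_id: "abc123" };
const second = { completion: "Hello there", stop_reason: "stop_sequence", model: "claude-v1", log_id: "abc123" };

let lastPosition = 0;
const delta1 = first.completion.slice(lastPosition); // "Hello"
lastPosition = first.completion.length; // 5
const delta2 = second.completion.slice(lastPosition); // " there"
// delta1/delta2 become choices[0].delta.content in the emitted
// "chat.completion.chunk" events.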
"../../../config"; import { logger } from "../../../logger"; import { keyPool } from "../../../key-management"; +import { incrementPromptCount } from "../../auth/user-store"; import { buildFakeSseMessage, enqueue, trackWaitTime } from "../../queue"; +import { isCompletionRequest } from "../request"; import { handleStreamedResponse } from "./handle-streamed-response"; import { logPrompt } from "./log-prompt"; -import { incrementPromptCount } from "../../auth/user-store"; -export const QUOTA_ROUTES = ["/v1/chat/completions"]; const DECODER_MAP = { gzip: util.promisify(zlib.gunzip), deflate: util.promisify(zlib.inflate), @@ -174,7 +176,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( } else { const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`; logger.warn({ contentEncoding, key: req.key?.hash }, errorMessage); - writeErrorResponse(res, 500, { + writeErrorResponse(req, res, 500, { error: errorMessage, contentEncoding, }); @@ -191,7 +193,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( } catch (error: any) { const errorMessage = `Proxy received response with invalid JSON: ${error.message}`; logger.warn({ error, key: req.key?.hash }, errorMessage); - writeErrorResponse(res, 500, { error: errorMessage }); + writeErrorResponse(req, res, 500, { error: errorMessage }); return reject(errorMessage); } }); @@ -199,8 +201,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async ( return promise; }; -// TODO: This is too specific to OpenAI's error responses, Anthropic errors -// will need a different handler. +// TODO: This is too specific to OpenAI's error responses. /** * Handles non-2xx responses from the upstream service. If the proxied response * is an error, this will respond to the client with an error payload and throw @@ -237,7 +238,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( } } catch (parseError: any) { const statusMessage = proxyRes.statusMessage || "Unknown error"; - // Likely Bad Gateway or Gateway Timeout from OpenAI's Cloudflare proxy + // Likely Bad Gateway or Gateway Timeout from reverse proxy/load balancer logger.warn( { statusCode, statusMessage, key: req.key?.hash }, parseError.message @@ -249,7 +250,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( error: parseError.message, proxy_note: `This is likely a temporary error with the upstream service.`, }; - writeErrorResponse(res, statusCode, errorObject); + writeErrorResponse(req, res, statusCode, errorObject); throw new Error(parseError.message); } @@ -265,47 +266,35 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( if (statusCode === 400) { // Bad request (likely prompt is too long) - errorPayload.proxy_note = `OpenAI rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`; + errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`; } else if (statusCode === 401) { // Key is invalid or was revoked keyPool.disable(req.key!); - errorPayload.proxy_note = `The OpenAI key is invalid or revoked. ${tryAgainMessage}`; + errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`; } else if (statusCode === 429) { - const type = errorPayload.error?.type; - if (type === "insufficient_quota") { - // Billing quota exceeded (key is dead, disable it) - keyPool.disable(req.key!); - errorPayload.proxy_note = `Assigned key's quota has been exceeded. 
${tryAgainMessage}`; - } else if (type === "billing_not_active") { - // Billing is not active (key is dead, disable it) - keyPool.disable(req.key!); - errorPayload.proxy_note = `Assigned key was deactivated by OpenAI. ${tryAgainMessage}`; - } else if (type === "requests" || type === "tokens") { - // Per-minute request or token rate limit is exceeded, which we can retry - keyPool.markRateLimited(req.key!.hash); - if (config.queueMode !== "none") { - reenqueueRequest(req); - // TODO: I don't like using an error to control flow here - throw new RetryableError("Rate-limited request re-enqueued."); - } - errorPayload.proxy_note = `Assigned key's '${type}' rate limit has been exceeded. Try again later.`; + // OpenAI uses this for a bunch of different rate-limiting scenarios. + if (req.key!.service === "openai") { + handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload); + } else { + handleAnthropicRateLimitError(req, errorPayload); + } } else if (statusCode === 404) { // Most likely model not found - // TODO: this probably doesn't handle GPT-4-32k variants properly if the - // proxy has keys for both the 8k and 32k context models at the same time. - if (errorPayload.error?.code === "model_not_found") { - if (req.key!.isGpt4) { - errorPayload.proxy_note = `Assigned key isn't provisioned for the GPT-4 snapshot you requested. Try again to get a different key, or use Turbo.`; - } else { - errorPayload.proxy_note = `No model was found for this key.`; + if (req.key!.service === "openai") { + // TODO: this probably doesn't handle GPT-4-32k variants properly if the + // proxy has keys for both the 8k and 32k context models at the same time. + if (errorPayload.error?.code === "model_not_found") { + if (req.key!.isGpt4) { + errorPayload.proxy_note = `Assigned key isn't provisioned for the GPT-4 snapshot you requested. Try again to get a different key, or use Turbo.`; + } else { + errorPayload.proxy_note = `No model was found for this key.`; + } } + } else if (req.key!.service === "anthropic") { + errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`; } } else { - errorPayload.proxy_note = `Unrecognized error from OpenAI.`; + errorPayload.proxy_note = `Unrecognized error from upstream service.`; } // Some OAI errors contain the organization ID, which we don't want to reveal. @@ -316,15 +305,68 @@ ); } - writeErrorResponse(res, statusCode, errorPayload); + writeErrorResponse(req, res, statusCode, errorPayload); throw new Error(errorPayload.error?.message); }; +function handleAnthropicRateLimitError( + req: Request, + errorPayload: Record<string, any> +) { + // {"error":{"type":"rate_limit_error","message":"Number of concurrent connections to Claude exceeds your rate limit. Please try again, or contact sales@anthropic.com to discuss your options for a rate limit increase."}} + if (errorPayload.error?.type === "rate_limit_error") { + keyPool.markRateLimited(req.key!); + if (config.queueMode !== "none") { + reenqueueRequest(req); + throw new RetryableError("Claude rate-limited request re-enqueued."); + } + errorPayload.proxy_note = `There are too many in-flight requests for this key. Try again later.`; + } else { + errorPayload.proxy_note = `Unrecognized rate limit error from Anthropic.
Key may be over quota.`; + } +} + +function handleOpenAIRateLimitError( + req: Request, + tryAgainMessage: string, + errorPayload: Record<string, any> +): Record<string, any> { + const type = errorPayload.error?.type; + if (type === "insufficient_quota") { + // Billing quota exceeded (key is dead, disable it) + keyPool.disable(req.key!); + errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`; + } else if (type === "billing_not_active") { + // Billing is not active (key is dead, disable it) + keyPool.disable(req.key!); + errorPayload.proxy_note = `Assigned key was deactivated by OpenAI. ${tryAgainMessage}`; + } else if (type === "requests" || type === "tokens") { + // Per-minute request or token rate limit is exceeded, which we can retry + keyPool.markRateLimited(req.key!); + if (config.queueMode !== "none") { + reenqueueRequest(req); + // This is confusing, but it will bubble up to the top-level response + // handler and cause the request to go back into the request queue. + throw new RetryableError("Rate-limited request re-enqueued."); + } + errorPayload.proxy_note = `Assigned key's '${type}' rate limit has been exceeded. Try again later.`; + } else { + // OpenAI probably overloaded + errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`; + } + return errorPayload; +} function writeErrorResponse( + req: Request, res: Response, statusCode: number, errorPayload: Record<string, any> ) { + const errorSource = errorPayload.error?.type?.startsWith("proxy") ? "proxy" : "upstream"; + // If we're mid-SSE stream, send a data event with the error payload and end // the stream. Otherwise just send a normal error response. if ( @@ -332,8 +374,9 @@ res.getHeader("content-type") === "text/event-stream" ) { const msg = buildFakeSseMessage( - `upstream error (${statusCode})`, - JSON.stringify(errorPayload, null, 2) + `${errorSource} error (${statusCode})`, + JSON.stringify(errorPayload, null, 2), + req ); res.write(msg); res.write(`data: [DONE]\n\n`); @@ -344,21 +387,31 @@ } /** Handles errors in rewriter pipelines. */ -export const handleInternalError: httpProxy.ErrorCallback = ( - err, - _req, - res -) => { +export const handleInternalError: httpProxy.ErrorCallback = (err, req, res) => { logger.error({ error: err }, "Error in http-proxy-middleware pipeline."); + try { - writeErrorResponse(res as Response, 500, { - error: { - type: "proxy_error", - message: err.message, - stack: err.stack, - proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`, - }, - }); + const isZod = err instanceof ZodError; + if (isZod) { + writeErrorResponse(req as Request, res as Response, 400, { + error: { + type: "proxy_validation_error", + proxy_note: `Reverse proxy couldn't validate your request when trying to transform it. Your client may be sending invalid data.`, + issues: err.issues, + stack: err.stack, + message: err.message, + }, + }); + } else { + writeErrorResponse(req as Request, res as Response, 500, { + error: { + type: "proxy_rewriter_error", + proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`, + message: err.message, + stack: err.stack, + }, + }); + } } catch (e) { logger.error( { error: e }, @@ -368,8 +421,8 @@ }; const incrementKeyUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => { - if (QUOTA_ROUTES.includes(req.path)) { - keyPool.incrementPrompt(req.key?.hash); + if (isCompletionRequest(req)) { + keyPool.incrementPrompt(req.key!); if (req.user) { incrementPromptCount(req.user.token); } @@ -377,7 +430,7 @@ }; const trackRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => { - keyPool.updateRateLimits(req.key!.hash, proxyRes.headers); + keyPool.updateRateLimits(req.key!, proxyRes.headers); }; const copyHttpHeaders: ProxyResHandlerWithBody = async (
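The RetryableError dance above is the one piece of control flow that spans several files, so here is a condensed sketch of the round trip. The top-level catch is not included in this diff, so the shape below is an assumption based on the comments, not the project's actual code:

class RetryableError extends Error {} // stand-in for the class in response/index.ts

// Assumed shape of the top-level response handler's catch:
async function runResponsePipeline(pipeline: () => Promise<void>) {
  try {
    await pipeline(); // may call reenqueueRequest(req), then throw RetryableError
  } catch (error) {
    if (error instanceof RetryableError) {
      // The request is already back in the queue; write nothing and keep the
      // client connection open. transformOutboundPayload's retryCount guard
      // keeps the payload from being transformed twice on the retry.
      return;
    }
    throw error; // genuine errors still reach the error handler
  }
}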
diff --git a/src/proxy/middleware/response/log-prompt.ts b/src/proxy/middleware/response/log-prompt.ts index 32ec1af..751f418 100644 --- a/src/proxy/middleware/response/log-prompt.ts +++ b/src/proxy/middleware/response/log-prompt.ts @@ -1,4 +1,5 @@ import { config } from "../../../config"; +import { AIService } from "../../../key-management"; import { logQueue } from "../../../prompt-logging"; import { isCompletionRequest } from "../request"; import { ProxyResHandlerWithBody } from "."; @@ -17,18 +18,16 @@ export const logPrompt: ProxyResHandlerWithBody = async ( throw new Error("Expected body to be an object"); } - // Only log prompts if we're making a request to a completion endpoint if (!isCompletionRequest(req)) { - // Remove this once we're confident that we're not missing any prompts - req.log.info( - `Not logging prompt for ${req.path} because it's not a completion endpoint` - ); return; } const model = req.body.model; const promptFlattened = flattenMessages(req.body.messages); - const response = getResponseForModel({ model, body: responseBody }); + const response = getResponseForService({ + service: req.key!.service, + body: responseBody, + }); logQueue.enqueue({ model, @@ -48,15 +47,14 @@ const flattenMessages = (messages: OaiMessage[]): string => { return messages.map((m) => `${m.role}: ${m.content}`).join("\n"); }; -const getResponseForModel = ({ - model, +const getResponseForService = ({ + service, body, }: { - model: string; + service: AIService; body: Record<string, any>; }) => { - if (model.startsWith("claude")) { - // TODO: confirm if there is supposed to be a leading space + if (service === "anthropic") { return body.completion.trim(); } else { return body.choices[0].message.content; diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts index 780a56f..080677e 100644 --- a/src/proxy/openai.ts +++ b/src/proxy/openai.ts @@ -8,10 +8,10 @@ import { ipLimiter } from "./rate-limit"; import { addKey, languageFilter, - checkStreaming, finalizeBody, limitOutputTokens, limitCompletions, + transformOutboundPayload, } from "./middleware/request"; import { createOnProxyResHandler, @@ -28,9 +28,9 @@ const rewriteRequest = ( const rewriterPipeline = [ addKey, languageFilter, - checkStreaming, limitOutputTokens, limitCompletions, + transformOutboundPayload, finalizeBody, ]; @@ -39,7 +39,7 @@ const rewriteRequest = ( rewriter(proxyReq, req, res, {}); } } catch (error)
{ - logger.error(error, "Error while executing proxy rewriter"); + req.log.error(error, "Error while executing proxy rewriter"); proxyReq.destroy(error as Error); } }; @@ -98,7 +98,7 @@ openaiRouter.get("*", (req, res, next) => { } }); openaiRouter.use((req, res) => { - logger.warn(`Blocked openai proxy request: ${req.method} ${req.path}`); + req.log.warn(`Blocked openai proxy request: ${req.method} ${req.path}`); res.status(404).json({ error: "Not found" }); }); diff --git a/src/proxy/queue.ts b/src/proxy/queue.ts index b15e2dd..dc8aa46 100644 --- a/src/proxy/queue.ts +++ b/src/proxy/queue.ts @@ -17,7 +17,7 @@ import type { Handler, Request } from "express"; import { config, DequeueMode } from "../config"; -import { keyPool } from "../key-management"; +import { keyPool, SupportedModel } from "../key-management"; import { logger } from "../logger"; import { AGNAI_DOT_CHAT_IP } from "./rate-limit"; @@ -78,7 +78,7 @@ export function enqueue(req: Request) { // If the request opted into streaming, we need to register a heartbeat // handler to keep the connection alive while it waits in the queue. We // deregister the handler when the request is dequeued. - if (req.body.stream) { + if (req.body.stream === "true" || req.body.stream === true) { const res = req.res!; if (!res.headersSent) { initStreaming(req); @@ -91,7 +91,7 @@ export function enqueue(req: Request) { const avgWait = Math.round(getEstimatedWaitTime() / 1000); const currentDuration = Math.round((Date.now() - req.startTime) / 1000); const debugMsg = `queue length: ${queue.length}; elapsed time: ${currentDuration}s; avg wait: ${avgWait}s`; - req.res!.write(buildFakeSseMessage("heartbeat", debugMsg)); + req.res!.write(buildFakeSseMessage("heartbeat", debugMsg, req)); } }, 10000); } @@ -118,12 +118,24 @@ export function enqueue(req: Request) { } } -export function dequeue(model: string): Request | undefined { - // TODO: This should be set by some middleware that checks the request body. - const modelQueue = - model === "gpt-4" - ? queue.filter((req) => req.body.model?.startsWith("gpt-4")) - : queue.filter((req) => !req.body.model?.startsWith("gpt-4")); +export function dequeue(model: SupportedModel): Request | undefined { + const modelQueue = queue.filter((req) => { + const reqProvider = req.originalUrl.startsWith("/proxy/anthropic") + ? "anthropic" + : "openai"; + + // This sucks, but the `req.body.model` on Anthropic requests via the + // OpenAI-compat endpoint isn't actually claude-*, it's a fake gpt value. + // TODO: refactor model/service detection + + if (model.startsWith("claude")) { + return reqProvider === "anthropic"; + } + if (model.startsWith("gpt-4")) { + return reqProvider === "openai" && req.body.model?.startsWith("gpt-4"); + } + return reqProvider === "openai" && req.body.model?.startsWith("gpt-3"); + }); if (modelQueue.length === 0) { return undefined; @@ -172,6 +184,7 @@ function processQueue() { // the others, because we only track one rate limit per key. 
const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4"); const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo"); + const claudeLockout = keyPool.getLockoutPeriod("claude-v1"); const reqs: (Request | undefined)[] = []; if (gpt4Lockout === 0) { @@ -180,6 +193,9 @@ function processQueue() { if (turboLockout === 0) { reqs.push(dequeue("gpt-3.5-turbo")); } + if (claudeLockout === 0) { + reqs.push(dequeue("claude-v1")); + } reqs.filter(Boolean).forEach((req) => { if (req?.proceed) { @@ -266,7 +282,7 @@ export function createQueueMiddleware(proxyMiddleware: Handler): Handler { type: "proxy_error", message: err.message, stack: err.stack, - proxy_note: `Only one request per IP can be queued at a time. If you don't have another request queued, your IP may be in use by another user.`, + proxy_note: `Only one request can be queued at a time. If you don't have another request queued, your IP or user token might be in use by another request.`, }); } }; @@ -281,7 +297,11 @@ function killQueuedRequest(req: Request) { try { const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes. The queue is currently ${queue.length} requests long.`; if (res.headersSent) { - const fakeErrorEvent = buildFakeSseMessage("proxy queue error", message); + const fakeErrorEvent = buildFakeSseMessage( + "proxy queue error", + message, + req + ); res.write(fakeErrorEvent); res.end(); } else { @@ -305,20 +325,38 @@ function initStreaming(req: Request) { res.write(": joining queue\n\n"); } -export function buildFakeSseMessage(type: string, string: string) { - const fakeEvent = { - id: "chatcmpl-" + type, - object: "chat.completion.chunk", - created: Date.now(), - model: "", - choices: [ - { - delta: { content: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n` }, - index: 0, - finish_reason: type, - }, - ], - }; +export function buildFakeSseMessage( + type: string, + string: string, + req: Request +) { + let fakeEvent; + + if (req.api === "anthropic") { + // data: {"completion": " Here is a paragraph of lorem ipsum text:\n\nLorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor inc", "stop_reason": "max_tokens", "truncated": false, "stop": null, "model": "claude-instant-v1", "log_id": "???", "exception": null} + fakeEvent = { + completion: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`, + stop_reason: type, + truncated: false, // I've never seen this be true + stop: null, + model: req.body?.model, + log_id: "proxy-req-" + req.id, + }; + } else { + fakeEvent = { + id: "chatcmpl-" + req.id, + object: "chat.completion.chunk", + created: Date.now(), + model: req.body?.model, + choices: [ + { + delta: { content: `\`\`\`\n[${type}: ${string}]\n\`\`\`\n` }, + index: 0, + finish_reason: type, + }, + ], + }; + } return `data: ${JSON.stringify(fakeEvent)}\n\n`; } diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index 3edfde1..97e1e51 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -8,12 +8,14 @@ import * as express from "express"; import { gatekeeper } from "./auth/gatekeeper"; import { kobold } from "./kobold"; import { openai } from "./openai"; +import { anthropic } from "./anthropic"; const router = express.Router(); router.use(gatekeeper); router.use("/kobold", kobold); router.use("/openai", openai); +router.use("/anthropic", anthropic); // Each client handles the endpoints input by the user in slightly different // ways, eg TavernAI ignores everything after the hostname in Kobold mode diff --git a/src/types/custom.d.ts b/src/types/custom.d.ts 
index b67b033..4535ce1 100644 --- a/src/types/custom.d.ts +++ b/src/types/custom.d.ts @@ -1,11 +1,16 @@ import { Express } from "express-serve-static-core"; -import { Key } from "../key-management/key-pool"; +import { Key } from "../key-management/index"; import { User } from "../proxy/auth/user-store"; declare global { namespace Express { interface Request { key?: Key; + /** + * Denotes the _inbound_ API format. This is used to determine how the + * user has submitted their request; the proxy will then translate the + * parameters to the target API format, which is on `key.service`. + */ api: "kobold" | "openai" | "anthropic"; user?: User; isStreaming?: boolean;
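As a closing illustration of the convention this declaration encodes, the sketch below (illustrative only, not part of the diff) condenses the dispatch that transformOutboundPayload performs across the two axes: the inbound format on req.api and the outbound service on req.key.service.

type InboundApi = "kobold" | "openai" | "anthropic";
type OutboundService = "openai" | "anthropic";

function translationFor(inbound: InboundApi, outbound: OutboundService): string {
  if (inbound === outbound) return "passthrough"; // no transformation needed
  if (inbound === "openai" && outbound === "anthropic") {
    return "openaiToAnthropic"; // chat messages flattened into a Claude prompt
  }
  // anthropic -> openai is rejected until someone needs it; kobold input is
  // deprecated and handled, if at all, by transformKoboldPayload earlier.
  throw new Error(`Unsupported transformation: ${inbound} -> ${outbound}`);
}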