This commit is contained in:
khanon
2023-08-29 22:56:54 +00:00
parent 3c56103de0
commit 4d781e1720
14 changed files with 285 additions and 100 deletions
+1 -1
View File
@@ -10,7 +10,7 @@
# REJECT_DISALLOWED=false
# REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy."
# CHECK_KEYS=true
# TURBO_ONLY=false
# ALLOWED_MODEL_FAMILIES=claude,turbo,gpt4,gpt4-32k
# BLOCKED_ORIGINS=reddit.com,9gag.com
# BLOCK_MESSAGE="You must be over the age of majority in your country to use this service."
# BLOCK_REDIRECT="https://roblox.com/"
+25 -7
View File
@@ -1,6 +1,7 @@
import dotenv from "dotenv";
import type firebase from "firebase-admin";
import pino from "pino";
import type { ModelFamily } from "./key-management/models";
dotenv.config();
// Can't import the usual logger here because it itself needs the config.
@@ -112,11 +113,8 @@ type Config = {
* Destination URL to redirect blocked requests to, for non-JSON requests.
*/
blockRedirect?: string;
/**
* Whether the proxy should disallow requests for GPT-4 models in order to
* prevent excessive spend. Applies only to OpenAI.
*/
turboOnly?: boolean;
/** Which model families to allow requests for. Applies only to OpenAI. */
allowedModelFamilies: ModelFamily[];
/**
* The number of (LLM) tokens a user can consume before requests are rejected.
* Limits include both prompt and response tokens. `special` users are exempt.
@@ -170,6 +168,12 @@ export const config: Config = {
["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"],
400
),
allowedModelFamilies: getEnvWithDefault("ALLOWED_MODEL_FAMILIES", [
"turbo",
"gpt4",
"gpt4-32k",
"claude",
]),
rejectDisallowed: getEnvWithDefault("REJECT_DISALLOWED", false),
rejectMessage: getEnvWithDefault(
"REJECT_MESSAGE",
@@ -190,7 +194,6 @@ export const config: Config = {
"You must be over the age of majority in your country to use this service."
),
blockRedirect: getEnvWithDefault("BLOCK_REDIRECT", "https://www.9gag.com"),
turboOnly: getEnvWithDefault("TURBO_ONLY", false),
tokenQuota: {
turbo: getEnvWithDefault("TOKEN_QUOTA_TURBO", 0),
gpt4: getEnvWithDefault("TOKEN_QUOTA_GPT4", 0),
@@ -200,6 +203,15 @@ export const config: Config = {
} as const;
export async function assertConfigIsValid() {
if (process.env.TURBO_ONLY === "true") {
startupLogger.warn(
"TURBO_ONLY is deprecated. Use ALLOWED_MODEL_FAMILIES=turbo instead."
);
config.allowedModelFamilies = config.allowedModelFamilies.filter(
(f) => !f.includes("gpt4")
);
}
if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) {
throw new Error(
`Invalid gatekeeper mode: ${config.gatekeeper}. Must be one of: none, proxy_key, user_token.`
@@ -298,7 +310,7 @@ export function listConfig(obj: Config = config): Record<string, any> {
result[key] = value;
}
if (typeof obj[key] === "object") {
if (typeof obj[key] === "object" && !Array.isArray(obj[key])) {
result[key] = listConfig(obj[key] as unknown as Config);
}
}
@@ -320,6 +332,12 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
if (env === "OPENAI_KEY" || env === "ANTHROPIC_KEY") {
return value as unknown as T;
}
// Intended to be used for comma-delimited lists
if (Array.isArray(defaultValue)) {
return value.split(",").map((v) => v.trim()) as T;
}
return JSON.parse(value) as T;
} catch (err) {
return value as unknown as T;
+86 -24
View File
@@ -2,13 +2,14 @@ import fs from "fs";
import { Request, Response } from "express";
import showdown from "showdown";
import { config, listConfig } from "./config";
import { OpenAIKey, keyPool } from "./key-management";
import { getUniqueIps } from "./proxy/rate-limit";
import {
QueuePartition,
getEstimatedWaitTime,
getQueueLength,
} from "./proxy/queue";
ModelFamily,
OpenAIKey,
OpenAIModelFamily,
keyPool,
} from "./key-management";
import { getUniqueIps } from "./proxy/rate-limit";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
const INFO_PAGE_TTL = 5000;
let infoPageHtml: string | undefined;
@@ -78,34 +79,65 @@ function cacheInfoPageHtml(baseUrl: string) {
type ServiceInfo = {
activeKeys: number;
trialKeys?: number;
// activeLimit: string;
revokedKeys?: number;
overQuotaKeys?: number;
proomptersInQueue: number;
estimatedQueueTime: string;
};
// this has long since outgrown this awful "dump everything in a <pre> tag" approach
// but I really don't want to spend time on a proper UI for this right now
/**
* This may end up doing a very very large number of wasted iterations over
* potentially large key lists, don't call it too often.
*/
function getOpenAIInfo() {
const info: { [model: string]: Partial<ServiceInfo> } = {};
const keys = keyPool
.list()
.filter((k) => k.service === "openai") as OpenAIKey[];
const hasGpt4 = keys.some((k) => k.isGpt4) && !config.turboOnly;
const allowed = new Set(config.allowedModelFamilies);
let available = new Set<OpenAIModelFamily>();
const keys = keyPool.list().filter((k) => {
if (k.service === "openai") {
k.modelFamilies.forEach((f) => available.add(f as OpenAIModelFamily));
return true;
}
return false;
}) as Omit<OpenAIKey, "key">[];
available = new Set([...available].filter((f) => allowed.has(f)));
if (keyPool.anyUnchecked()) {
const uncheckedKeys = keys.filter((k) => !k.lastChecked);
info.status =
`Performing startup key checks (${uncheckedKeys.length} left).` as any;
`Performing key checks (${uncheckedKeys.length} left).` as any;
} else {
delete info.status;
}
if (config.checkKeys) {
const turboKeys = keys.filter((k) => !k.isGpt4);
const gpt4Keys = keys.filter((k) => k.isGpt4);
const keysByModel = keys.reduce(
(acc, k) => {
// only put keys in the most important family they belong to.
// if a model family is disabled, key will be in the next most
// important family.
if (k.modelFamilies.includes("gpt4-32k") && allowed.has("gpt4-32k")) {
acc["gpt4-32k"].push(k);
} else if (k.modelFamilies.includes("gpt4") && allowed.has("gpt4")) {
acc["gpt4"].push(k);
} else {
acc["turbo"].push(k);
}
return acc;
},
{ turbo: [], gpt4: [], "gpt4-32k": [] } as Record<
OpenAIModelFamily,
Omit<OpenAIKey, "key">[]
>
);
const turboKeys = keysByModel["turbo"];
const gpt4Keys = keysByModel["gpt4"];
const gpt432kKeys = keysByModel["gpt4-32k"];
// this is fucked
info.turbo = {
activeKeys: turboKeys.filter((k) => !k.isDisabled).length,
@@ -114,7 +146,7 @@ function getOpenAIInfo() {
overQuotaKeys: turboKeys.filter((k) => k.isOverQuota).length,
};
if (hasGpt4) {
if (available.has("gpt4")) {
info.gpt4 = {
activeKeys: gpt4Keys.filter((k) => !k.isDisabled).length,
trialKeys: gpt4Keys.filter((k) => k.isTrial).length,
@@ -122,11 +154,22 @@ function getOpenAIInfo() {
overQuotaKeys: gpt4Keys.filter((k) => k.isOverQuota).length,
};
}
if (available.has("gpt4-32k")) {
info["gpt4-32k"] = {
activeKeys: gpt432kKeys.filter((k) => !k.isDisabled).length,
trialKeys: gpt432kKeys.filter((k) => k.isTrial).length,
revokedKeys: gpt432kKeys.filter((k) => k.isRevoked).length,
overQuotaKeys: gpt432kKeys.filter((k) => k.isOverQuota).length,
};
}
} else {
info.status = "Key checking is disabled." as any;
info.turbo = { activeKeys: keys.filter((k) => !k.isDisabled).length };
info.gpt4 = {
activeKeys: keys.filter((k) => !k.isDisabled && k.isGpt4).length,
activeKeys: keys.filter(
(k) => !k.isDisabled && k.modelFamilies.includes("gpt4")
).length,
};
}
@@ -135,12 +178,18 @@ function getOpenAIInfo() {
info.turbo.proomptersInQueue = turboQueue.proomptersInQueue;
info.turbo.estimatedQueueTime = turboQueue.estimatedQueueTime;
if (hasGpt4) {
const gpt4Queue = getQueueInformation("gpt-4");
if (available.has("gpt4")) {
const gpt4Queue = getQueueInformation("gpt4");
info.gpt4.proomptersInQueue = gpt4Queue.proomptersInQueue;
info.gpt4.estimatedQueueTime = gpt4Queue.estimatedQueueTime;
}
if (available.has("gpt4-32k")) {
const gpt432kQueue = getQueueInformation("gpt4-32k");
info["gpt4-32k"].proomptersInQueue = gpt432kQueue.proomptersInQueue;
info["gpt4-32k"].estimatedQueueTime = gpt432kQueue.estimatedQueueTime;
}
return info;
}
@@ -180,12 +229,25 @@ Logs are anonymous and do not contain IP addresses or timestamps. [You can see t
infoBody += `\n## Estimated Wait Times\nIf the AI is busy, your prompt will processed when a slot frees up.`;
if (config.openaiKey) {
// this is also fucked
const keys = keyPool.list().filter((k) => k.service === "openai");
const turboWait = getQueueInformation("turbo").estimatedQueueTime;
const gpt4Wait = getQueueInformation("gpt-4").estimatedQueueTime;
waits.push(`**Turbo:** ${turboWait}`);
if (keyPool.list().some((k) => k.isGpt4) && !config.turboOnly) {
const gpt4Wait = getQueueInformation("gpt4").estimatedQueueTime;
const hasGpt4 = keys.some((k) => k.modelFamilies.includes("gpt4"));
const allowedGpt4 = config.allowedModelFamilies.includes("gpt4");
if (hasGpt4 && allowedGpt4) {
waits.push(`**GPT-4:** ${gpt4Wait}`);
}
const gpt432kWait = getQueueInformation("gpt4-32k").estimatedQueueTime;
const hasGpt432k = keys.some((k) => k.modelFamilies.includes("gpt4-32k"));
const allowedGpt432k = config.allowedModelFamilies.includes("gpt4-32k");
if (hasGpt432k && allowedGpt432k) {
waits.push(`**GPT-4-32k:** ${gpt432kWait}`);
}
}
if (config.anthropicKey) {
@@ -202,7 +264,7 @@ ${customGreeting}`;
}
/** Returns queue time in seconds, or minutes + seconds if over 60 seconds. */
function getQueueInformation(partition: QueuePartition) {
function getQueueInformation(partition: ModelFamily) {
const waitMs = getEstimatedWaitTime(partition);
const waitTime =
waitMs < 60000
+3 -1
View File
@@ -2,6 +2,7 @@ import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../config";
import { logger } from "../../logger";
import type { AnthropicModelFamily } from "../models";
// https://docs.anthropic.com/claude/reference/selecting-a-model
export const ANTHROPIC_SUPPORTED_MODELS = [
@@ -25,6 +26,7 @@ export type AnthropicKeyUpdate = Omit<
export interface AnthropicKey extends Key {
readonly service: "anthropic";
readonly modelFamilies: AnthropicModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
@@ -71,7 +73,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
const newKey: AnthropicKey = {
key,
service: this.service,
isGpt4: false,
modelFamilies: ["claude"],
isTrial: false,
isDisabled: false,
promptCount: 0,
+9 -2
View File
@@ -4,6 +4,7 @@ import {
AnthropicModel,
} from "./anthropic/provider";
import { KeyPool } from "./key-pool";
import type { ModelFamily } from "./models";
export type AIService = "openai" | "anthropic";
export type Model = OpenAIModel | AnthropicModel;
@@ -15,8 +16,8 @@ export interface Key {
service: AIService;
/** Whether this is a free trial key. These are prioritized over paid keys if they can fulfill the request. */
isTrial: boolean;
/** Whether this key has been provisioned for GPT-4. */
isGpt4: boolean;
/** The model families that this key has access to. */
modelFamilies: ModelFamily[];
/** Whether this key is currently disabled, meaning its quota has been exceeded or it has been revoked. */
isDisabled: boolean;
/** The number of prompts that have been sent with this key. */
@@ -65,3 +66,9 @@ export type SupportedModel = (typeof SUPPORTED_MODELS)[number];
export { OPENAI_SUPPORTED_MODELS, ANTHROPIC_SUPPORTED_MODELS };
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export type {
OpenAIModelFamily,
AnthropicModelFamily,
ModelFamily,
} from "./models";
export { getOpenAIModelFamily, getClaudeModelFamily } from "./models";
+6
View File
@@ -3,6 +3,7 @@ import schedule from "node-schedule";
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { Key, Model, KeyProvider, AIService } from "./index";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { config } from "../config";
import { logger } from "../logger";
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
@@ -89,6 +90,11 @@ export class KeyPool {
}
public recheck(service: AIService): void {
if (!config.checkKeys) {
logger.info("Skipping key recheck because key checking is disabled");
return;
}
const provider = this.getKeyProvider(service);
if (provider instanceof OpenAIKeyProvider) {
provider.recheck();
+27
View File
@@ -0,0 +1,27 @@
import { logger } from "../logger";
/** Granular OpenAI model families used for key provisioning and quota tracking. */
export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k";
/** Anthropic models are all grouped under a single "claude" family. */
export type AnthropicModelFamily = "claude";
export type ModelFamily = OpenAIModelFamily | AnthropicModelFamily;
/** Maps a model-name regex source string to the family it belongs to. */
export type ModelFamilyMap = { [regex: string]: ModelFamily };

// Patterns are evaluated in insertion order, so the more specific gpt-4-32k
// patterns must stay ahead of the plain gpt-4 ones.
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
  "^gpt-4-32k-\\d{4}$": "gpt4-32k",
  "^gpt-4-32k$": "gpt4-32k",
  "^gpt-4-\\d{4}$": "gpt4",
  "^gpt-4$": "gpt4",
  // Dot is escaped so that only a literal "gpt-3.5-turbo" prefix matches,
  // not arbitrary lookalikes such as "gpt-355-turbo".
  "^gpt-3\\.5-turbo": "turbo",
};
/**
 * Resolves an OpenAI model name to its model family. Patterns in
 * OPENAI_MODEL_FAMILY_MAP are tried in insertion order; unrecognized models
 * are logged and assumed to be "gpt4".
 */
export function getOpenAIModelFamily(model: string): OpenAIModelFamily {
  const entry = Object.entries(OPENAI_MODEL_FAMILY_MAP).find(([pattern]) =>
    model.match(pattern)
  );
  if (entry) {
    return entry[1];
  }
  // Capture a stack trace so the offending caller can be identified in logs.
  const stack = new Error().stack;
  logger.warn({ model, stack }, "Unmapped model family");
  return "gpt4";
}
/**
 * Resolves an Anthropic model name to its model family. There is currently
 * only one Anthropic family, so every model maps to "claude".
 */
export function getClaudeModelFamily(_model: string): ModelFamily {
  const family: ModelFamily = "claude";
  return family;
}
+29 -11
View File
@@ -1,6 +1,7 @@
import axios, { AxiosError } from "axios";
import { logger } from "../../logger";
import type { OpenAIKey, OpenAIKeyProvider } from "./provider";
import type { OpenAIModelFamily } from "../models";
/** Minimum time in between any two key checks. */
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
@@ -136,7 +137,7 @@ export class OpenAIKeyChecker {
this.maybeCreateOrganizationClones(key),
]);
const updates = {
isGpt4: provisionedModels.gpt4,
modelFamilies: provisionedModels,
isTrial: livenessTest.rateLimit <= 250,
softLimit: 0,
hardLimit: 0,
@@ -149,7 +150,10 @@ export class OpenAIKeyChecker {
const updates = { softLimit: 0, hardLimit: 0, systemHardLimit: 0 };
this.updateKey(key.hash, updates);
}
this.log.info({ key: key.hash }, "Key check complete.");
this.log.info(
{ key: key.hash, models: key.modelFamilies },
"Key check complete."
);
} catch (error) {
// touch the key so we don't check it again for a while
this.updateKey(key.hash, {});
@@ -166,23 +170,35 @@ export class OpenAIKeyChecker {
private async getProvisionedModels(
key: OpenAIKey
): Promise<{ turbo: boolean; gpt4: boolean }> {
): Promise<OpenAIModelFamily[]> {
const opts = { headers: OpenAIKeyChecker.getHeaders(key) };
const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
const models = data.data;
const turbo = models.some(({ id }) => id.startsWith("gpt-3.5"));
const gpt4 = models.some(({ id }) => id.startsWith("gpt-4"));
// We want to update the key's `isGpt4` flag here, but we don't want to
const families: OpenAIModelFamily[] = [];
if (models.some(({ id }) => id.startsWith("gpt-3.5-turbo"))) {
families.push("turbo");
}
if (models.some(({ id }) => id.startsWith("gpt-4"))) {
families.push("gpt4");
}
if (models.some(({ id }) => id.startsWith("gpt-4-32k"))) {
families.push("gpt4-32k");
}
// We want to update the key's model families here, but we don't want to
// update its `lastChecked` timestamp because we need to let the liveness
// check run before we can consider the key checked.
// Need to use `find` here because keys are cloned from the pool.
const keyFromPool = this.keys.find((k) => k.hash === key.hash)!;
this.updateKey(key.hash, {
isGpt4: gpt4,
modelFamilies: families,
lastChecked: keyFromPool.lastChecked,
});
return { turbo, gpt4 };
return families;
}
private async maybeCreateOrganizationClones(key: OpenAIKey) {
@@ -219,7 +235,7 @@ export class OpenAIKeyChecker {
this.updateKey(key.hash, {
isDisabled: true,
isRevoked: true,
isGpt4: false,
modelFamilies: ["turbo"],
});
} else if (status === 429) {
switch (data.error.type) {
@@ -228,7 +244,9 @@ export class OpenAIKeyChecker {
case "billing_not_active":
const isOverQuota = data.error.type === "insufficient_quota";
const isRevoked = !isOverQuota;
const isGpt4 = isRevoked ? false : key.isGpt4;
const modelFamilies: OpenAIModelFamily[] = isRevoked
? ["turbo"]
: key.modelFamilies;
this.log.warn(
{ key: key.hash, rateLimitType: data.error.type, error: data },
"Key returned a non-transient 429 error. Disabling key."
@@ -237,7 +255,7 @@ export class OpenAIKeyChecker {
isDisabled: true,
isRevoked,
isOverQuota,
isGpt4,
modelFamilies,
});
break;
case "requests":
+40 -15
View File
@@ -9,8 +9,9 @@ import { KeyProvider, Key, Model } from "../index";
import { config } from "../../config";
import { logger } from "../../logger";
import { OpenAIKeyChecker } from "./checker";
import { OpenAIModelFamily, getOpenAIModelFamily } from "../models";
export type OpenAIModel = "gpt-3.5-turbo" | "gpt-4";
export type OpenAIModel = "gpt-3.5-turbo" | "gpt-4" | "gpt-4-32k";
export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
"gpt-3.5-turbo",
"gpt-4",
@@ -18,6 +19,7 @@ export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
export interface OpenAIKey extends Key {
readonly service: "openai";
modelFamilies: OpenAIModelFamily[];
/**
* Some keys are assigned to multiple organizations, each with their own quota
* limits. We clone the key for each organization and track usage/disabled
@@ -85,7 +87,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
const newKey = {
key: k,
service: "openai" as const,
isGpt4: true,
modelFamilies: ["turbo" as const, "gpt4" as const],
isTrial: false,
isDisabled: false,
isRevoked: false,
@@ -134,35 +136,39 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
}
public get(model: Model) {
const needGpt4 = model.startsWith("gpt-4");
const neededFamily = getOpenAIModelFamily(model);
const availableKeys = this.keys.filter(
(key) => !key.isDisabled && (!needGpt4 || key.isGpt4)
(key) => !key.isDisabled && key.modelFamilies.includes(neededFamily)
);
if (availableKeys.length === 0) {
let message = needGpt4
? "No GPT-4 keys available. Try selecting a Turbo model."
: "No active OpenAI keys available.";
throw new Error(message);
throw new Error(`No active keys available for ${neededFamily} models.`);
}
if (needGpt4 && config.turboOnly) {
if (!config.allowedModelFamilies.includes(neededFamily)) {
throw new Error(
"Proxy operator has disabled GPT-4 to reduce quota usage. Try selecting a Turbo model."
`Proxy operator has disabled access to ${neededFamily} models.`
);
}
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. We ignore rate limits from over a minute ago
// a. We ignore rate limits from >30 seconds ago
// b. If all keys were rate limited in the last minute, select the
// least recently rate limited key
// 2. Keys which are trials
// 3. Keys which have not been used in the longest time
// 3. Keys which do *not* have access to GPT-4-32k
// 4. Keys which have not been used in the longest time
const now = Date.now();
const rateLimitThreshold = 60 * 1000;
const rateLimitThreshold = 30 * 1000;
const keysByPriority = availableKeys.sort((a, b) => {
// TODO: this isn't quite right; keys are briefly artificially rate-
// limited when they are selected, so this will deprioritize keys that
// may not actually be limited, simply because they were used recently.
// This should be adjusted to use a new `rateLimitedUntil` field instead
// of `rateLimitedAt`.
const aRateLimited = now - a.rateLimitedAt < rateLimitThreshold;
const bRateLimited = now - b.rateLimitedAt < rateLimitThreshold;
@@ -171,13 +177,32 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
// Neither key is rate limited, continue
if (a.isTrial && !b.isTrial) return -1;
if (!a.isTrial && b.isTrial) return 1;
// Neither or both keys are trials, continue
const aHas32k = a.modelFamilies.includes("gpt4-32k");
const bHas32k = b.modelFamilies.includes("gpt4-32k");
if (aHas32k && !bHas32k) return 1;
if (!aHas32k && bHas32k) return -1;
// Neither or both keys have 32k, continue
return a.lastUsed - b.lastUsed;
});
// logger.debug(
// {
// byPriority: keysByPriority.map((k) => ({
// hash: k.hash,
// isRateLimited: now - k.rateLimitedAt < rateLimitThreshold,
// modelFamilies: k.modelFamilies,
// })),
// },
// "Keys sorted by priority"
// );
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
@@ -243,9 +268,9 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
* the request, or returns 0 if a key is ready immediately.
*/
public getLockoutPeriod(model: Model = "gpt-4"): number {
const needGpt4 = model.startsWith("gpt-4");
const neededFamily = getOpenAIModelFamily(model);
const activeKeys = this.keys.filter(
(key) => !key.isDisabled && (!needGpt4 || key.isGpt4)
(key) => !key.isDisabled && key.modelFamilies.includes(neededFamily)
);
if (activeKeys.length === 0) {
+4 -4
View File
@@ -160,7 +160,7 @@ export function incrementTokenCount(
) {
const user = users.get(token);
if (!user) return;
const modelFamily = getModelFamily(model);
const modelFamily = getModelFamilyForQuotaUsage(model);
user.tokenCounts[modelFamily] += consumption;
usersToFlush.add(token);
}
@@ -197,7 +197,7 @@ export function hasAvailableQuota(
if (!user) return false;
if (user.type === "special") return true;
const modelFamily = getModelFamily(model);
const modelFamily = getModelFamilyForQuotaUsage(model);
const { tokenCounts, tokenLimits } = user;
const tokenLimit = tokenLimits[modelFamily];
@@ -281,9 +281,9 @@ async function flushUsers() {
log.info({ users: Object.keys(updates).length }, "Flushed users to Firebase");
}
function getModelFamily(model: string): QuotaModel {
// TODO: add gpt-4-32k models; use key-management/models.ts for family mapping
function getModelFamilyForQuotaUsage(model: string): QuotaModel {
if (model.startsWith("gpt-4")) {
// TODO: add 32k models
return "gpt4";
}
if (model.startsWith("gpt-3.5")) {
@@ -99,7 +99,7 @@ export const transformKoboldPayload: ProxyRequestMiddleware = (
// Kobold doesn't select a model. If the addKey rewriter assigned us a GPT-4
// key, use that. Otherwise, use GPT-3.5-turbo.
const model = req.key!.isGpt4 ? "gpt-4" : "gpt-3.5-turbo";
const model = "gpt-4";
const newBody = {
model,
temperature,
+8 -6
View File
@@ -4,7 +4,7 @@ import * as http from "http";
import util from "util";
import zlib from "zlib";
import { logger } from "../../../logger";
import { keyPool } from "../../../key-management";
import { getOpenAIModelFamily, keyPool } from "../../../key-management";
import { enqueue, trackWaitTime } from "../../queue";
import {
incrementPromptCount,
@@ -297,11 +297,13 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
// TODO: this probably doesn't handle GPT-4-32k variants properly if the
// proxy has keys for both the 8k and 32k context models at the same time.
if (errorPayload.error?.code === "model_not_found") {
if (req.key!.isGpt4) {
errorPayload.proxy_note = `Assigned key isn't provisioned for the GPT-4 snapshot you requested. Try again to get a different key, or use Turbo.`;
} else {
errorPayload.proxy_note = `No model was found for this key.`;
}
const requestedModel = req.body.model;
const modelFamily = getOpenAIModelFamily(requestedModel);
errorPayload.proxy_note = `The key assigned to your prompt does not support the requested model (${requestedModel}, family: ${modelFamily}).`;
req.log.error(
{ key: req.key?.hash, model: requestedModel, modelFamily },
"Prompt was routed to a key that does not support the requested model."
);
}
} else if (req.outboundApi === "anthropic") {
errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
+23 -15
View File
@@ -2,7 +2,12 @@ import { RequestHandler, Request, Router } from "express";
import * as http from "http";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../key-management";
import {
ModelFamily,
OpenAIModelFamily,
getOpenAIModelFamily,
keyPool,
} from "../key-management";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
@@ -31,25 +36,33 @@ function getModelsResponse() {
}
// https://platform.openai.com/docs/models/overview
const gptVariants = [
const knownModels = [
"gpt-4",
"gpt-4-0613",
"gpt-4-0314", // EOL 2023-09-13
"gpt-4-0314", // EOL 2024-06-13
"gpt-4-32k",
"gpt-4-32k-0613",
"gpt-4-32k-0314", // EOL 2023-09-13
"gpt-4-32k-0314", // EOL 2024-06-13
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301", // EOL 2023-09-13
"gpt-3.5-turbo-0301", // EOL 2024-06-13
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613",
];
const gpt4Available = keyPool.list().filter((key) => {
return key.service === "openai" && !key.isDisabled && key.isGpt4;
}).length;
let available = new Set<OpenAIModelFamily>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "openai") continue;
key.modelFamilies.forEach((family) =>
available.add(family as OpenAIModelFamily)
);
}
const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
available = new Set([...available].filter((x) => allowed.has(x)));
const models = gptVariants
console.log(available);
const models = knownModels
.map((id) => ({
id,
object: "model",
@@ -68,12 +81,7 @@ function getModelsResponse() {
root: id,
parent: null,
}))
.filter((model) => {
if (model.id.startsWith("gpt-4")) {
return gpt4Available > 0;
}
return true;
});
.filter((model) => available.has(getOpenAIModelFamily(model.id)));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
+23 -13
View File
@@ -16,13 +16,17 @@
*/
import type { Handler, Request } from "express";
import { keyPool, SupportedModel } from "../key-management";
import {
getClaudeModelFamily,
getOpenAIModelFamily,
keyPool,
ModelFamily,
SupportedModel,
} from "../key-management";
import { logger } from "../logger";
import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
import { buildFakeSseMessage } from "./middleware/common";
export type QueuePartition = "claude" | "turbo" | "gpt-4";
const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
@@ -129,7 +133,7 @@ export function enqueue(req: Request) {
}
}
function getPartitionForRequest(req: Request): QueuePartition {
function getPartitionForRequest(req: Request): ModelFamily {
// There is a single request queue, but it is partitioned by model and API
// provider.
// - claude: requests for the Anthropic API, regardless of model
@@ -138,19 +142,19 @@ function getPartitionForRequest(req: Request): QueuePartition {
const provider = req.outboundApi;
const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo";
if (provider === "anthropic") {
return "claude";
return getClaudeModelFamily(model);
}
if (provider === "openai" && model.startsWith("gpt-4")) {
return "gpt-4";
if (provider === "openai") {
return getOpenAIModelFamily(model);
}
return "turbo";
}
function getQueueForPartition(partition: QueuePartition): Request[] {
function getQueueForPartition(partition: ModelFamily): Request[] {
return queue.filter((req) => getPartitionForRequest(req) === partition);
}
export function dequeue(partition: QueuePartition): Request | undefined {
export function dequeue(partition: ModelFamily): Request | undefined {
const modelQueue = getQueueForPartition(partition);
if (modelQueue.length === 0) {
@@ -189,13 +193,19 @@ function processQueue() {
// This isn't completely correct, because a key can service multiple models.
// Currently if a key is locked out on one model it will also stop servicing
// the others, because we only track one rate limit per key.
// TODO: `getLockoutPeriod` uses model names instead of model families
const gpt432kLockout = keyPool.getLockoutPeriod("gpt-4-32k");
const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4");
const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo");
const claudeLockout = keyPool.getLockoutPeriod("claude-v1");
const reqs: (Request | undefined)[] = [];
if (gpt432kLockout === 0) {
reqs.push(dequeue("gpt4-32k"));
}
if (gpt4Lockout === 0) {
reqs.push(dequeue("gpt-4"));
reqs.push(dequeue("gpt4"));
}
if (turboLockout === 0) {
reqs.push(dequeue("turbo"));
@@ -244,7 +254,7 @@ export function start() {
log.info(`Started request queue.`);
}
let waitTimes: { partition: QueuePartition; start: number; end: number }[] = [];
let waitTimes: { partition: ModelFamily; start: number; end: number }[] = [];
/** Adds a successful request to the list of wait times. */
export function trackWaitTime(req: Request) {
@@ -256,7 +266,7 @@ export function trackWaitTime(req: Request) {
}
/** Returns average wait time in milliseconds. */
export function getEstimatedWaitTime(partition: QueuePartition) {
export function getEstimatedWaitTime(partition: ModelFamily) {
const now = Date.now();
const recentWaits = waitTimes.filter(
(wt) => wt.partition === partition && now - wt.end < 300 * 1000
@@ -271,7 +281,7 @@ export function getEstimatedWaitTime(partition: QueuePartition) {
);
}
export function getQueueLength(partition: QueuePartition | "all" = "all") {
export function getQueueLength(partition: ModelFamily | "all" = "all") {
if (partition === "all") {
return queue.length;
}