Add GPT-4-32k support (khanon/oai-reverse-proxy!39)
This commit is contained in:
+1
-1
@@ -10,7 +10,7 @@
|
||||
# REJECT_DISALLOWED=false
|
||||
# REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy."
|
||||
# CHECK_KEYS=true
|
||||
# TURBO_ONLY=false
|
||||
# ALLOWED_MODEL_FAMILIES=claude,turbo,gpt4,gpt4-32k
|
||||
# BLOCKED_ORIGINS=reddit.com,9gag.com
|
||||
# BLOCK_MESSAGE="You must be over the age of majority in your country to use this service."
|
||||
# BLOCK_REDIRECT="https://roblox.com/"
|
||||
|
||||
+25
-7
@@ -1,6 +1,7 @@
|
||||
import dotenv from "dotenv";
|
||||
import type firebase from "firebase-admin";
|
||||
import pino from "pino";
|
||||
import type { ModelFamily } from "./key-management/models";
|
||||
dotenv.config();
|
||||
|
||||
// Can't import the usual logger here because it itself needs the config.
|
||||
@@ -112,11 +113,8 @@ type Config = {
|
||||
* Desination URL to redirect blocked requests to, for non-JSON requests.
|
||||
*/
|
||||
blockRedirect?: string;
|
||||
/**
|
||||
* Whether the proxy should disallow requests for GPT-4 models in order to
|
||||
* prevent excessive spend. Applies only to OpenAI.
|
||||
*/
|
||||
turboOnly?: boolean;
|
||||
/** Which model families to allow requests for. Applies only to OpenAI. */
|
||||
allowedModelFamilies: ModelFamily[];
|
||||
/**
|
||||
* The number of (LLM) tokens a user can consume before requests are rejected.
|
||||
* Limits include both prompt and response tokens. `special` users are exempt.
|
||||
@@ -170,6 +168,12 @@ export const config: Config = {
|
||||
["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"],
|
||||
400
|
||||
),
|
||||
allowedModelFamilies: getEnvWithDefault("ALLOWED_MODEL_FAMILIES", [
|
||||
"turbo",
|
||||
"gpt4",
|
||||
"gpt4-32k",
|
||||
"claude",
|
||||
]),
|
||||
rejectDisallowed: getEnvWithDefault("REJECT_DISALLOWED", false),
|
||||
rejectMessage: getEnvWithDefault(
|
||||
"REJECT_MESSAGE",
|
||||
@@ -190,7 +194,6 @@ export const config: Config = {
|
||||
"You must be over the age of majority in your country to use this service."
|
||||
),
|
||||
blockRedirect: getEnvWithDefault("BLOCK_REDIRECT", "https://www.9gag.com"),
|
||||
turboOnly: getEnvWithDefault("TURBO_ONLY", false),
|
||||
tokenQuota: {
|
||||
turbo: getEnvWithDefault("TOKEN_QUOTA_TURBO", 0),
|
||||
gpt4: getEnvWithDefault("TOKEN_QUOTA_GPT4", 0),
|
||||
@@ -200,6 +203,15 @@ export const config: Config = {
|
||||
} as const;
|
||||
|
||||
export async function assertConfigIsValid() {
|
||||
if (process.env.TURBO_ONLY === "true") {
|
||||
startupLogger.warn(
|
||||
"TURBO_ONLY is deprecated. Use ALLOWED_MODEL_FAMILIES=turbo instead."
|
||||
);
|
||||
config.allowedModelFamilies = config.allowedModelFamilies.filter(
|
||||
(f) => !f.includes("gpt4")
|
||||
);
|
||||
}
|
||||
|
||||
if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) {
|
||||
throw new Error(
|
||||
`Invalid gatekeeper mode: ${config.gatekeeper}. Must be one of: none, proxy_key, user_token.`
|
||||
@@ -298,7 +310,7 @@ export function listConfig(obj: Config = config): Record<string, any> {
|
||||
result[key] = value;
|
||||
}
|
||||
|
||||
if (typeof obj[key] === "object") {
|
||||
if (typeof obj[key] === "object" && !Array.isArray(obj[key])) {
|
||||
result[key] = listConfig(obj[key] as unknown as Config);
|
||||
}
|
||||
}
|
||||
@@ -320,6 +332,12 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
|
||||
if (env === "OPENAI_KEY" || env === "ANTHROPIC_KEY") {
|
||||
return value as unknown as T;
|
||||
}
|
||||
|
||||
// Intended to be used for comma-delimited lists
|
||||
if (Array.isArray(defaultValue)) {
|
||||
return value.split(",").map((v) => v.trim()) as T;
|
||||
}
|
||||
|
||||
return JSON.parse(value) as T;
|
||||
} catch (err) {
|
||||
return value as unknown as T;
|
||||
|
||||
+86
-24
@@ -2,13 +2,14 @@ import fs from "fs";
|
||||
import { Request, Response } from "express";
|
||||
import showdown from "showdown";
|
||||
import { config, listConfig } from "./config";
|
||||
import { OpenAIKey, keyPool } from "./key-management";
|
||||
import { getUniqueIps } from "./proxy/rate-limit";
|
||||
import {
|
||||
QueuePartition,
|
||||
getEstimatedWaitTime,
|
||||
getQueueLength,
|
||||
} from "./proxy/queue";
|
||||
ModelFamily,
|
||||
OpenAIKey,
|
||||
OpenAIModelFamily,
|
||||
keyPool,
|
||||
} from "./key-management";
|
||||
import { getUniqueIps } from "./proxy/rate-limit";
|
||||
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
|
||||
|
||||
const INFO_PAGE_TTL = 5000;
|
||||
let infoPageHtml: string | undefined;
|
||||
@@ -78,34 +79,65 @@ function cacheInfoPageHtml(baseUrl: string) {
|
||||
type ServiceInfo = {
|
||||
activeKeys: number;
|
||||
trialKeys?: number;
|
||||
// activeLimit: string;
|
||||
revokedKeys?: number;
|
||||
overQuotaKeys?: number;
|
||||
proomptersInQueue: number;
|
||||
estimatedQueueTime: string;
|
||||
};
|
||||
|
||||
// this has long since outgrown this awful "dump everything in a <pre> tag" approach
|
||||
// but I really don't want to spend time on a proper UI for this right now
|
||||
|
||||
/**
|
||||
* This may end up doing a very very large number of wasted iterations over
|
||||
* potentially large key lists, don't call it too often.
|
||||
*/
|
||||
function getOpenAIInfo() {
|
||||
const info: { [model: string]: Partial<ServiceInfo> } = {};
|
||||
const keys = keyPool
|
||||
.list()
|
||||
.filter((k) => k.service === "openai") as OpenAIKey[];
|
||||
const hasGpt4 = keys.some((k) => k.isGpt4) && !config.turboOnly;
|
||||
const allowed = new Set(config.allowedModelFamilies);
|
||||
let available = new Set<OpenAIModelFamily>();
|
||||
|
||||
const keys = keyPool.list().filter((k) => {
|
||||
if (k.service === "openai") {
|
||||
k.modelFamilies.forEach((f) => available.add(f as OpenAIModelFamily));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}) as Omit<OpenAIKey, "key">[];
|
||||
|
||||
available = new Set([...available].filter((f) => allowed.has(f)));
|
||||
|
||||
if (keyPool.anyUnchecked()) {
|
||||
const uncheckedKeys = keys.filter((k) => !k.lastChecked);
|
||||
info.status =
|
||||
`Performing startup key checks (${uncheckedKeys.length} left).` as any;
|
||||
`Performing key checks (${uncheckedKeys.length} left).` as any;
|
||||
} else {
|
||||
delete info.status;
|
||||
}
|
||||
|
||||
if (config.checkKeys) {
|
||||
const turboKeys = keys.filter((k) => !k.isGpt4);
|
||||
const gpt4Keys = keys.filter((k) => k.isGpt4);
|
||||
const keysByModel = keys.reduce(
|
||||
(acc, k) => {
|
||||
// only put keys in the most important family they belong to.
|
||||
// if a model family is disabled, key will be in the next most
|
||||
// important family.
|
||||
if (k.modelFamilies.includes("gpt4-32k") && allowed.has("gpt4-32k")) {
|
||||
acc["gpt4-32k"].push(k);
|
||||
} else if (k.modelFamilies.includes("gpt4") && allowed.has("gpt4")) {
|
||||
acc["gpt4"].push(k);
|
||||
} else {
|
||||
acc["turbo"].push(k);
|
||||
}
|
||||
return acc;
|
||||
},
|
||||
{ turbo: [], gpt4: [], "gpt4-32k": [] } as Record<
|
||||
OpenAIModelFamily,
|
||||
Omit<OpenAIKey, "key">[]
|
||||
>
|
||||
);
|
||||
|
||||
const turboKeys = keysByModel["turbo"];
|
||||
const gpt4Keys = keysByModel["gpt4"];
|
||||
const gpt432kKeys = keysByModel["gpt4-32k"];
|
||||
|
||||
// this is fucked
|
||||
|
||||
info.turbo = {
|
||||
activeKeys: turboKeys.filter((k) => !k.isDisabled).length,
|
||||
@@ -114,7 +146,7 @@ function getOpenAIInfo() {
|
||||
overQuotaKeys: turboKeys.filter((k) => k.isOverQuota).length,
|
||||
};
|
||||
|
||||
if (hasGpt4) {
|
||||
if (available.has("gpt4")) {
|
||||
info.gpt4 = {
|
||||
activeKeys: gpt4Keys.filter((k) => !k.isDisabled).length,
|
||||
trialKeys: gpt4Keys.filter((k) => k.isTrial).length,
|
||||
@@ -122,11 +154,22 @@ function getOpenAIInfo() {
|
||||
overQuotaKeys: gpt4Keys.filter((k) => k.isOverQuota).length,
|
||||
};
|
||||
}
|
||||
|
||||
if (available.has("gpt4-32k")) {
|
||||
info["gpt4-32k"] = {
|
||||
activeKeys: gpt432kKeys.filter((k) => !k.isDisabled).length,
|
||||
trialKeys: gpt432kKeys.filter((k) => k.isTrial).length,
|
||||
revokedKeys: gpt432kKeys.filter((k) => k.isRevoked).length,
|
||||
overQuotaKeys: gpt432kKeys.filter((k) => k.isOverQuota).length,
|
||||
};
|
||||
}
|
||||
} else {
|
||||
info.status = "Key checking is disabled." as any;
|
||||
info.turbo = { activeKeys: keys.filter((k) => !k.isDisabled).length };
|
||||
info.gpt4 = {
|
||||
activeKeys: keys.filter((k) => !k.isDisabled && k.isGpt4).length,
|
||||
activeKeys: keys.filter(
|
||||
(k) => !k.isDisabled && k.modelFamilies.includes("gpt4")
|
||||
).length,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -135,12 +178,18 @@ function getOpenAIInfo() {
|
||||
info.turbo.proomptersInQueue = turboQueue.proomptersInQueue;
|
||||
info.turbo.estimatedQueueTime = turboQueue.estimatedQueueTime;
|
||||
|
||||
if (hasGpt4) {
|
||||
const gpt4Queue = getQueueInformation("gpt-4");
|
||||
if (available.has("gpt4")) {
|
||||
const gpt4Queue = getQueueInformation("gpt4");
|
||||
info.gpt4.proomptersInQueue = gpt4Queue.proomptersInQueue;
|
||||
info.gpt4.estimatedQueueTime = gpt4Queue.estimatedQueueTime;
|
||||
}
|
||||
|
||||
if (available.has("gpt4-32k")) {
|
||||
const gpt432kQueue = getQueueInformation("gpt4-32k");
|
||||
info["gpt4-32k"].proomptersInQueue = gpt432kQueue.proomptersInQueue;
|
||||
info["gpt4-32k"].estimatedQueueTime = gpt432kQueue.estimatedQueueTime;
|
||||
}
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
@@ -180,12 +229,25 @@ Logs are anonymous and do not contain IP addresses or timestamps. [You can see t
|
||||
infoBody += `\n## Estimated Wait Times\nIf the AI is busy, your prompt will processed when a slot frees up.`;
|
||||
|
||||
if (config.openaiKey) {
|
||||
// this is also fucked
|
||||
const keys = keyPool.list().filter((k) => k.service === "openai");
|
||||
|
||||
const turboWait = getQueueInformation("turbo").estimatedQueueTime;
|
||||
const gpt4Wait = getQueueInformation("gpt-4").estimatedQueueTime;
|
||||
waits.push(`**Turbo:** ${turboWait}`);
|
||||
if (keyPool.list().some((k) => k.isGpt4) && !config.turboOnly) {
|
||||
|
||||
const gpt4Wait = getQueueInformation("gpt4").estimatedQueueTime;
|
||||
const hasGpt4 = keys.some((k) => k.modelFamilies.includes("gpt4"));
|
||||
const allowedGpt4 = config.allowedModelFamilies.includes("gpt4");
|
||||
if (hasGpt4 && allowedGpt4) {
|
||||
waits.push(`**GPT-4:** ${gpt4Wait}`);
|
||||
}
|
||||
|
||||
const gpt432kWait = getQueueInformation("gpt4-32k").estimatedQueueTime;
|
||||
const hasGpt432k = keys.some((k) => k.modelFamilies.includes("gpt4-32k"));
|
||||
const allowedGpt432k = config.allowedModelFamilies.includes("gpt4-32k");
|
||||
if (hasGpt432k && allowedGpt432k) {
|
||||
waits.push(`**GPT-4-32k:** ${gpt432kWait}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (config.anthropicKey) {
|
||||
@@ -202,7 +264,7 @@ ${customGreeting}`;
|
||||
}
|
||||
|
||||
/** Returns queue time in seconds, or minutes + seconds if over 60 seconds. */
|
||||
function getQueueInformation(partition: QueuePartition) {
|
||||
function getQueueInformation(partition: ModelFamily) {
|
||||
const waitMs = getEstimatedWaitTime(partition);
|
||||
const waitTime =
|
||||
waitMs < 60000
|
||||
|
||||
@@ -2,6 +2,7 @@ import crypto from "crypto";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { config } from "../../config";
|
||||
import { logger } from "../../logger";
|
||||
import type { AnthropicModelFamily } from "../models";
|
||||
|
||||
// https://docs.anthropic.com/claude/reference/selecting-a-model
|
||||
export const ANTHROPIC_SUPPORTED_MODELS = [
|
||||
@@ -25,6 +26,7 @@ export type AnthropicKeyUpdate = Omit<
|
||||
|
||||
export interface AnthropicKey extends Key {
|
||||
readonly service: "anthropic";
|
||||
readonly modelFamilies: AnthropicModelFamily[];
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/** The time until which this key is rate limited. */
|
||||
@@ -71,7 +73,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
|
||||
const newKey: AnthropicKey = {
|
||||
key,
|
||||
service: this.service,
|
||||
isGpt4: false,
|
||||
modelFamilies: ["claude"],
|
||||
isTrial: false,
|
||||
isDisabled: false,
|
||||
promptCount: 0,
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
AnthropicModel,
|
||||
} from "./anthropic/provider";
|
||||
import { KeyPool } from "./key-pool";
|
||||
import type { ModelFamily } from "./models";
|
||||
|
||||
export type AIService = "openai" | "anthropic";
|
||||
export type Model = OpenAIModel | AnthropicModel;
|
||||
@@ -15,8 +16,8 @@ export interface Key {
|
||||
service: AIService;
|
||||
/** Whether this is a free trial key. These are prioritized over paid keys if they can fulfill the request. */
|
||||
isTrial: boolean;
|
||||
/** Whether this key has been provisioned for GPT-4. */
|
||||
isGpt4: boolean;
|
||||
/** The model families that this key has access to. */
|
||||
modelFamilies: ModelFamily[];
|
||||
/** Whether this key is currently disabled, meaning its quota has been exceeded or it has been revoked. */
|
||||
isDisabled: boolean;
|
||||
/** The number of prompts that have been sent with this key. */
|
||||
@@ -65,3 +66,9 @@ export type SupportedModel = (typeof SUPPORTED_MODELS)[number];
|
||||
export { OPENAI_SUPPORTED_MODELS, ANTHROPIC_SUPPORTED_MODELS };
|
||||
export { AnthropicKey } from "./anthropic/provider";
|
||||
export { OpenAIKey } from "./openai/provider";
|
||||
export type {
|
||||
OpenAIModelFamily,
|
||||
AnthropicModelFamily,
|
||||
ModelFamily,
|
||||
} from "./models";
|
||||
export { getOpenAIModelFamily, getClaudeModelFamily } from "./models";
|
||||
|
||||
@@ -3,6 +3,7 @@ import schedule from "node-schedule";
|
||||
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
|
||||
import { Key, Model, KeyProvider, AIService } from "./index";
|
||||
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
||||
import { config } from "../config";
|
||||
import { logger } from "../logger";
|
||||
|
||||
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
|
||||
@@ -89,6 +90,11 @@ export class KeyPool {
|
||||
}
|
||||
|
||||
public recheck(service: AIService): void {
|
||||
if (!config.checkKeys) {
|
||||
logger.info("Skipping key recheck because key checking is disabled");
|
||||
return;
|
||||
}
|
||||
|
||||
const provider = this.getKeyProvider(service);
|
||||
if (provider instanceof OpenAIKeyProvider) {
|
||||
provider.recheck();
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import { logger } from "../logger";
|
||||
|
||||
export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k";
|
||||
export type AnthropicModelFamily = "claude";
|
||||
export type ModelFamily = OpenAIModelFamily | AnthropicModelFamily;
|
||||
export type ModelFamilyMap = { [regex: string]: ModelFamily };
|
||||
|
||||
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
||||
"^gpt-4-32k-\\d{4}$": "gpt4-32k",
|
||||
"^gpt-4-32k$": "gpt4-32k",
|
||||
"^gpt-4-\\d{4}$": "gpt4",
|
||||
"^gpt-4$": "gpt4",
|
||||
"^gpt-3.5-turbo": "turbo",
|
||||
};
|
||||
|
||||
export function getOpenAIModelFamily(model: string): OpenAIModelFamily {
|
||||
for (const [regex, family] of Object.entries(OPENAI_MODEL_FAMILY_MAP)) {
|
||||
if (model.match(regex)) return family;
|
||||
}
|
||||
const stack = new Error().stack;
|
||||
logger.warn({ model, stack }, "Unmapped model family");
|
||||
return "gpt4";
|
||||
}
|
||||
|
||||
export function getClaudeModelFamily(_model: string): ModelFamily {
|
||||
return "claude";
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
import axios, { AxiosError } from "axios";
|
||||
import { logger } from "../../logger";
|
||||
import type { OpenAIKey, OpenAIKeyProvider } from "./provider";
|
||||
import type { OpenAIModelFamily } from "../models";
|
||||
|
||||
/** Minimum time in between any two key checks. */
|
||||
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||
@@ -136,7 +137,7 @@ export class OpenAIKeyChecker {
|
||||
this.maybeCreateOrganizationClones(key),
|
||||
]);
|
||||
const updates = {
|
||||
isGpt4: provisionedModels.gpt4,
|
||||
modelFamilies: provisionedModels,
|
||||
isTrial: livenessTest.rateLimit <= 250,
|
||||
softLimit: 0,
|
||||
hardLimit: 0,
|
||||
@@ -149,7 +150,10 @@ export class OpenAIKeyChecker {
|
||||
const updates = { softLimit: 0, hardLimit: 0, systemHardLimit: 0 };
|
||||
this.updateKey(key.hash, updates);
|
||||
}
|
||||
this.log.info({ key: key.hash }, "Key check complete.");
|
||||
this.log.info(
|
||||
{ key: key.hash, models: key.modelFamilies },
|
||||
"Key check complete."
|
||||
);
|
||||
} catch (error) {
|
||||
// touch the key so we don't check it again for a while
|
||||
this.updateKey(key.hash, {});
|
||||
@@ -166,23 +170,35 @@ export class OpenAIKeyChecker {
|
||||
|
||||
private async getProvisionedModels(
|
||||
key: OpenAIKey
|
||||
): Promise<{ turbo: boolean; gpt4: boolean }> {
|
||||
): Promise<OpenAIModelFamily[]> {
|
||||
const opts = { headers: OpenAIKeyChecker.getHeaders(key) };
|
||||
const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
|
||||
const models = data.data;
|
||||
const turbo = models.some(({ id }) => id.startsWith("gpt-3.5"));
|
||||
const gpt4 = models.some(({ id }) => id.startsWith("gpt-4"));
|
||||
// We want to update the key's `isGpt4` flag here, but we don't want to
|
||||
|
||||
const families: OpenAIModelFamily[] = [];
|
||||
if (models.some(({ id }) => id.startsWith("gpt-3.5-turbo"))) {
|
||||
families.push("turbo");
|
||||
}
|
||||
|
||||
if (models.some(({ id }) => id.startsWith("gpt-4"))) {
|
||||
families.push("gpt4");
|
||||
}
|
||||
|
||||
if (models.some(({ id }) => id.startsWith("gpt-4-32k"))) {
|
||||
families.push("gpt4-32k");
|
||||
}
|
||||
|
||||
// We want to update the key's model families here, but we don't want to
|
||||
// update its `lastChecked` timestamp because we need to let the liveness
|
||||
// check run before we can consider the key checked.
|
||||
|
||||
// Need to use `find` here because keys are cloned from the pool.
|
||||
const keyFromPool = this.keys.find((k) => k.hash === key.hash)!;
|
||||
this.updateKey(key.hash, {
|
||||
isGpt4: gpt4,
|
||||
modelFamilies: families,
|
||||
lastChecked: keyFromPool.lastChecked,
|
||||
});
|
||||
return { turbo, gpt4 };
|
||||
return families;
|
||||
}
|
||||
|
||||
private async maybeCreateOrganizationClones(key: OpenAIKey) {
|
||||
@@ -219,7 +235,7 @@ export class OpenAIKeyChecker {
|
||||
this.updateKey(key.hash, {
|
||||
isDisabled: true,
|
||||
isRevoked: true,
|
||||
isGpt4: false,
|
||||
modelFamilies: ["turbo"],
|
||||
});
|
||||
} else if (status === 429) {
|
||||
switch (data.error.type) {
|
||||
@@ -228,7 +244,9 @@ export class OpenAIKeyChecker {
|
||||
case "billing_not_active":
|
||||
const isOverQuota = data.error.type === "insufficient_quota";
|
||||
const isRevoked = !isOverQuota;
|
||||
const isGpt4 = isRevoked ? false : key.isGpt4;
|
||||
const modelFamilies: OpenAIModelFamily[] = isRevoked
|
||||
? ["turbo"]
|
||||
: key.modelFamilies;
|
||||
this.log.warn(
|
||||
{ key: key.hash, rateLimitType: data.error.type, error: data },
|
||||
"Key returned a non-transient 429 error. Disabling key."
|
||||
@@ -237,7 +255,7 @@ export class OpenAIKeyChecker {
|
||||
isDisabled: true,
|
||||
isRevoked,
|
||||
isOverQuota,
|
||||
isGpt4,
|
||||
modelFamilies,
|
||||
});
|
||||
break;
|
||||
case "requests":
|
||||
|
||||
@@ -9,8 +9,9 @@ import { KeyProvider, Key, Model } from "../index";
|
||||
import { config } from "../../config";
|
||||
import { logger } from "../../logger";
|
||||
import { OpenAIKeyChecker } from "./checker";
|
||||
import { OpenAIModelFamily, getOpenAIModelFamily } from "../models";
|
||||
|
||||
export type OpenAIModel = "gpt-3.5-turbo" | "gpt-4";
|
||||
export type OpenAIModel = "gpt-3.5-turbo" | "gpt-4" | "gpt-4-32k";
|
||||
export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4",
|
||||
@@ -18,6 +19,7 @@ export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
|
||||
|
||||
export interface OpenAIKey extends Key {
|
||||
readonly service: "openai";
|
||||
modelFamilies: OpenAIModelFamily[];
|
||||
/**
|
||||
* Some keys are assigned to multiple organizations, each with their own quota
|
||||
* limits. We clone the key for each organization and track usage/disabled
|
||||
@@ -85,7 +87,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
||||
const newKey = {
|
||||
key: k,
|
||||
service: "openai" as const,
|
||||
isGpt4: true,
|
||||
modelFamilies: ["turbo" as const, "gpt4" as const],
|
||||
isTrial: false,
|
||||
isDisabled: false,
|
||||
isRevoked: false,
|
||||
@@ -134,35 +136,39 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
||||
}
|
||||
|
||||
public get(model: Model) {
|
||||
const needGpt4 = model.startsWith("gpt-4");
|
||||
const neededFamily = getOpenAIModelFamily(model);
|
||||
const availableKeys = this.keys.filter(
|
||||
(key) => !key.isDisabled && (!needGpt4 || key.isGpt4)
|
||||
(key) => !key.isDisabled && key.modelFamilies.includes(neededFamily)
|
||||
);
|
||||
|
||||
if (availableKeys.length === 0) {
|
||||
let message = needGpt4
|
||||
? "No GPT-4 keys available. Try selecting a Turbo model."
|
||||
: "No active OpenAI keys available.";
|
||||
throw new Error(message);
|
||||
throw new Error(`No active keys available for ${neededFamily} models.`);
|
||||
}
|
||||
|
||||
if (needGpt4 && config.turboOnly) {
|
||||
if (!config.allowedModelFamilies.includes(neededFamily)) {
|
||||
throw new Error(
|
||||
"Proxy operator has disabled GPT-4 to reduce quota usage. Try selecting a Turbo model."
|
||||
`Proxy operator has disabled access to ${neededFamily} models.`
|
||||
);
|
||||
}
|
||||
|
||||
// Select a key, from highest priority to lowest priority:
|
||||
// 1. Keys which are not rate limited
|
||||
// a. We ignore rate limits from over a minute ago
|
||||
// a. We ignore rate limits from >30 seconds ago
|
||||
// b. If all keys were rate limited in the last minute, select the
|
||||
// least recently rate limited key
|
||||
// 2. Keys which are trials
|
||||
// 3. Keys which have not been used in the longest time
|
||||
// 3. Keys which do *not* have access to GPT-4-32k
|
||||
// 4. Keys which have not been used in the longest time
|
||||
|
||||
const now = Date.now();
|
||||
const rateLimitThreshold = 60 * 1000;
|
||||
const rateLimitThreshold = 30 * 1000;
|
||||
|
||||
const keysByPriority = availableKeys.sort((a, b) => {
|
||||
// TODO: this isn't quite right; keys are briefly artificially rate-
|
||||
// limited when they are selected, so this will deprioritize keys that
|
||||
// may not actually be limited, simply because they were used recently.
|
||||
// This should be adjusted to use a new `rateLimitedUntil` field instead
|
||||
// of `rateLimitedAt`.
|
||||
const aRateLimited = now - a.rateLimitedAt < rateLimitThreshold;
|
||||
const bRateLimited = now - b.rateLimitedAt < rateLimitThreshold;
|
||||
|
||||
@@ -171,13 +177,32 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
||||
if (aRateLimited && bRateLimited) {
|
||||
return a.rateLimitedAt - b.rateLimitedAt;
|
||||
}
|
||||
// Neither key is rate limited, continue
|
||||
|
||||
if (a.isTrial && !b.isTrial) return -1;
|
||||
if (!a.isTrial && b.isTrial) return 1;
|
||||
// Neither or both keys are trials, continue
|
||||
|
||||
const aHas32k = a.modelFamilies.includes("gpt4-32k");
|
||||
const bHas32k = b.modelFamilies.includes("gpt4-32k");
|
||||
if (aHas32k && !bHas32k) return 1;
|
||||
if (!aHas32k && bHas32k) return -1;
|
||||
// Neither or both keys have 32k, continue
|
||||
|
||||
return a.lastUsed - b.lastUsed;
|
||||
});
|
||||
|
||||
// logger.debug(
|
||||
// {
|
||||
// byPriority: keysByPriority.map((k) => ({
|
||||
// hash: k.hash,
|
||||
// isRateLimited: now - k.rateLimitedAt < rateLimitThreshold,
|
||||
// modelFamilies: k.modelFamilies,
|
||||
// })),
|
||||
// },
|
||||
// "Keys sorted by priority"
|
||||
// );
|
||||
|
||||
const selectedKey = keysByPriority[0];
|
||||
selectedKey.lastUsed = now;
|
||||
|
||||
@@ -243,9 +268,9 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
||||
* the request, or returns 0 if a key is ready immediately.
|
||||
*/
|
||||
public getLockoutPeriod(model: Model = "gpt-4"): number {
|
||||
const needGpt4 = model.startsWith("gpt-4");
|
||||
const neededFamily = getOpenAIModelFamily(model);
|
||||
const activeKeys = this.keys.filter(
|
||||
(key) => !key.isDisabled && (!needGpt4 || key.isGpt4)
|
||||
(key) => !key.isDisabled && key.modelFamilies.includes(neededFamily)
|
||||
);
|
||||
|
||||
if (activeKeys.length === 0) {
|
||||
|
||||
@@ -160,7 +160,7 @@ export function incrementTokenCount(
|
||||
) {
|
||||
const user = users.get(token);
|
||||
if (!user) return;
|
||||
const modelFamily = getModelFamily(model);
|
||||
const modelFamily = getModelFamilyForQuotaUsage(model);
|
||||
user.tokenCounts[modelFamily] += consumption;
|
||||
usersToFlush.add(token);
|
||||
}
|
||||
@@ -197,7 +197,7 @@ export function hasAvailableQuota(
|
||||
if (!user) return false;
|
||||
if (user.type === "special") return true;
|
||||
|
||||
const modelFamily = getModelFamily(model);
|
||||
const modelFamily = getModelFamilyForQuotaUsage(model);
|
||||
const { tokenCounts, tokenLimits } = user;
|
||||
const tokenLimit = tokenLimits[modelFamily];
|
||||
|
||||
@@ -281,9 +281,9 @@ async function flushUsers() {
|
||||
log.info({ users: Object.keys(updates).length }, "Flushed users to Firebase");
|
||||
}
|
||||
|
||||
function getModelFamily(model: string): QuotaModel {
|
||||
// TODO: add gpt-4-32k models; use key-management/models.ts for family mapping
|
||||
function getModelFamilyForQuotaUsage(model: string): QuotaModel {
|
||||
if (model.startsWith("gpt-4")) {
|
||||
// TODO: add 32k models
|
||||
return "gpt4";
|
||||
}
|
||||
if (model.startsWith("gpt-3.5")) {
|
||||
|
||||
@@ -99,7 +99,7 @@ export const transformKoboldPayload: ProxyRequestMiddleware = (
|
||||
// Kobold doesn't select a model. If the addKey rewriter assigned us a GPT-4
|
||||
// key, use that. Otherwise, use GPT-3.5-turbo.
|
||||
|
||||
const model = req.key!.isGpt4 ? "gpt-4" : "gpt-3.5-turbo";
|
||||
const model = "gpt-4";
|
||||
const newBody = {
|
||||
model,
|
||||
temperature,
|
||||
|
||||
@@ -4,7 +4,7 @@ import * as http from "http";
|
||||
import util from "util";
|
||||
import zlib from "zlib";
|
||||
import { logger } from "../../../logger";
|
||||
import { keyPool } from "../../../key-management";
|
||||
import { getOpenAIModelFamily, keyPool } from "../../../key-management";
|
||||
import { enqueue, trackWaitTime } from "../../queue";
|
||||
import {
|
||||
incrementPromptCount,
|
||||
@@ -297,11 +297,13 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
// TODO: this probably doesn't handle GPT-4-32k variants properly if the
|
||||
// proxy has keys for both the 8k and 32k context models at the same time.
|
||||
if (errorPayload.error?.code === "model_not_found") {
|
||||
if (req.key!.isGpt4) {
|
||||
errorPayload.proxy_note = `Assigned key isn't provisioned for the GPT-4 snapshot you requested. Try again to get a different key, or use Turbo.`;
|
||||
} else {
|
||||
errorPayload.proxy_note = `No model was found for this key.`;
|
||||
}
|
||||
const requestedModel = req.body.model;
|
||||
const modelFamily = getOpenAIModelFamily(requestedModel);
|
||||
errorPayload.proxy_note = `The key assigned to your prompt does not support the requested model (${requestedModel}, family: ${modelFamily}).`;
|
||||
req.log.error(
|
||||
{ key: req.key?.hash, model: requestedModel, modelFamily },
|
||||
"Prompt was routed to a key that does not support the requested model."
|
||||
);
|
||||
}
|
||||
} else if (req.outboundApi === "anthropic") {
|
||||
errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
|
||||
|
||||
+23
-15
@@ -2,7 +2,12 @@ import { RequestHandler, Request, Router } from "express";
|
||||
import * as http from "http";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { config } from "../config";
|
||||
import { keyPool } from "../key-management";
|
||||
import {
|
||||
ModelFamily,
|
||||
OpenAIModelFamily,
|
||||
getOpenAIModelFamily,
|
||||
keyPool,
|
||||
} from "../key-management";
|
||||
import { logger } from "../logger";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
@@ -31,25 +36,33 @@ function getModelsResponse() {
|
||||
}
|
||||
|
||||
// https://platform.openai.com/docs/models/overview
|
||||
const gptVariants = [
|
||||
const knownModels = [
|
||||
"gpt-4",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-0314", // EOL 2023-09-13
|
||||
"gpt-4-0314", // EOL 2024-06-13
|
||||
"gpt-4-32k",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-32k-0314", // EOL 2023-09-13
|
||||
"gpt-4-32k-0314", // EOL 2024-06-13
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-0301", // EOL 2023-09-13
|
||||
"gpt-3.5-turbo-0301", // EOL 2024-06-13
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
];
|
||||
|
||||
const gpt4Available = keyPool.list().filter((key) => {
|
||||
return key.service === "openai" && !key.isDisabled && key.isGpt4;
|
||||
}).length;
|
||||
let available = new Set<OpenAIModelFamily>();
|
||||
for (const key of keyPool.list()) {
|
||||
if (key.isDisabled || key.service !== "openai") continue;
|
||||
key.modelFamilies.forEach((family) =>
|
||||
available.add(family as OpenAIModelFamily)
|
||||
);
|
||||
}
|
||||
const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
|
||||
available = new Set([...available].filter((x) => allowed.has(x)));
|
||||
|
||||
const models = gptVariants
|
||||
console.log(available);
|
||||
|
||||
const models = knownModels
|
||||
.map((id) => ({
|
||||
id,
|
||||
object: "model",
|
||||
@@ -68,12 +81,7 @@ function getModelsResponse() {
|
||||
root: id,
|
||||
parent: null,
|
||||
}))
|
||||
.filter((model) => {
|
||||
if (model.id.startsWith("gpt-4")) {
|
||||
return gpt4Available > 0;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
.filter((model) => available.has(getOpenAIModelFamily(model.id)));
|
||||
|
||||
modelsCache = { object: "list", data: models };
|
||||
modelsCacheTime = new Date().getTime();
|
||||
|
||||
+23
-13
@@ -16,13 +16,17 @@
|
||||
*/
|
||||
|
||||
import type { Handler, Request } from "express";
|
||||
import { keyPool, SupportedModel } from "../key-management";
|
||||
import {
|
||||
getClaudeModelFamily,
|
||||
getOpenAIModelFamily,
|
||||
keyPool,
|
||||
ModelFamily,
|
||||
SupportedModel,
|
||||
} from "../key-management";
|
||||
import { logger } from "../logger";
|
||||
import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
|
||||
import { buildFakeSseMessage } from "./middleware/common";
|
||||
|
||||
export type QueuePartition = "claude" | "turbo" | "gpt-4";
|
||||
|
||||
const queue: Request[] = [];
|
||||
const log = logger.child({ module: "request-queue" });
|
||||
|
||||
@@ -129,7 +133,7 @@ export function enqueue(req: Request) {
|
||||
}
|
||||
}
|
||||
|
||||
function getPartitionForRequest(req: Request): QueuePartition {
|
||||
function getPartitionForRequest(req: Request): ModelFamily {
|
||||
// There is a single request queue, but it is partitioned by model and API
|
||||
// provider.
|
||||
// - claude: requests for the Anthropic API, regardless of model
|
||||
@@ -138,19 +142,19 @@ function getPartitionForRequest(req: Request): QueuePartition {
|
||||
const provider = req.outboundApi;
|
||||
const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo";
|
||||
if (provider === "anthropic") {
|
||||
return "claude";
|
||||
return getClaudeModelFamily(model);
|
||||
}
|
||||
if (provider === "openai" && model.startsWith("gpt-4")) {
|
||||
return "gpt-4";
|
||||
if (provider === "openai") {
|
||||
return getOpenAIModelFamily(model);
|
||||
}
|
||||
return "turbo";
|
||||
}
|
||||
|
||||
function getQueueForPartition(partition: QueuePartition): Request[] {
|
||||
function getQueueForPartition(partition: ModelFamily): Request[] {
|
||||
return queue.filter((req) => getPartitionForRequest(req) === partition);
|
||||
}
|
||||
|
||||
export function dequeue(partition: QueuePartition): Request | undefined {
|
||||
export function dequeue(partition: ModelFamily): Request | undefined {
|
||||
const modelQueue = getQueueForPartition(partition);
|
||||
|
||||
if (modelQueue.length === 0) {
|
||||
@@ -189,13 +193,19 @@ function processQueue() {
|
||||
// This isn't completely correct, because a key can service multiple models.
|
||||
// Currently if a key is locked out on one model it will also stop servicing
|
||||
// the others, because we only track one rate limit per key.
|
||||
|
||||
// TODO: `getLockoutPeriod` uses model names instead of model families
|
||||
const gpt432kLockout = keyPool.getLockoutPeriod("gpt-4-32k");
|
||||
const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4");
|
||||
const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo");
|
||||
const claudeLockout = keyPool.getLockoutPeriod("claude-v1");
|
||||
|
||||
const reqs: (Request | undefined)[] = [];
|
||||
if (gpt432kLockout === 0) {
|
||||
reqs.push(dequeue("gpt4-32k"));
|
||||
}
|
||||
if (gpt4Lockout === 0) {
|
||||
reqs.push(dequeue("gpt-4"));
|
||||
reqs.push(dequeue("gpt4"));
|
||||
}
|
||||
if (turboLockout === 0) {
|
||||
reqs.push(dequeue("turbo"));
|
||||
@@ -244,7 +254,7 @@ export function start() {
|
||||
log.info(`Started request queue.`);
|
||||
}
|
||||
|
||||
let waitTimes: { partition: QueuePartition; start: number; end: number }[] = [];
|
||||
let waitTimes: { partition: ModelFamily; start: number; end: number }[] = [];
|
||||
|
||||
/** Adds a successful request to the list of wait times. */
|
||||
export function trackWaitTime(req: Request) {
|
||||
@@ -256,7 +266,7 @@ export function trackWaitTime(req: Request) {
|
||||
}
|
||||
|
||||
/** Returns average wait time in milliseconds. */
|
||||
export function getEstimatedWaitTime(partition: QueuePartition) {
|
||||
export function getEstimatedWaitTime(partition: ModelFamily) {
|
||||
const now = Date.now();
|
||||
const recentWaits = waitTimes.filter(
|
||||
(wt) => wt.partition === partition && now - wt.end < 300 * 1000
|
||||
@@ -271,7 +281,7 @@ export function getEstimatedWaitTime(partition: QueuePartition) {
|
||||
);
|
||||
}
|
||||
|
||||
export function getQueueLength(partition: QueuePartition | "all" = "all") {
|
||||
export function getQueueLength(partition: ModelFamily | "all" = "all") {
|
||||
if (partition === "all") {
|
||||
return queue.length;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user