khanon
2023-08-29 22:56:54 +00:00
parent 3c56103de0
commit 4d781e1720
14 changed files with 285 additions and 100 deletions
@@ -16,13 +16,17 @@
  */
 import type { Handler, Request } from "express";
-import { keyPool, SupportedModel } from "../key-management";
+import {
+  getClaudeModelFamily,
+  getOpenAIModelFamily,
+  keyPool,
+  ModelFamily,
+  SupportedModel,
+} from "../key-management";
 import { logger } from "../logger";
 import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
 import { buildFakeSseMessage } from "./middleware/common";
-export type QueuePartition = "claude" | "turbo" | "gpt-4";
 const queue: Request[] = [];
 const log = logger.child({ module: "request-queue" });
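The `ModelFamily` type and the two mapping helpers imported above live in `../key-management` and are not shown in this diff. A minimal sketch of their assumed shape, inferred from the family names used later in this file ("turbo", "gpt4", "gpt4-32k", "claude"):

// Assumed shapes only; the real definitions in ../key-management may differ.
export type ModelFamily = "turbo" | "gpt4" | "gpt4-32k" | "claude";

export function getOpenAIModelFamily(model: string): ModelFamily {
  if (model.startsWith("gpt-4-32k")) return "gpt4-32k";
  if (model.startsWith("gpt-4")) return "gpt4";
  return "turbo";
}

export function getClaudeModelFamily(_model: string): ModelFamily {
  // Assumption: all Anthropic models map to a single family.
  return "claude";
}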
@@ -129,7 +133,7 @@ export function enqueue(req: Request) {
   }
 }
-function getPartitionForRequest(req: Request): QueuePartition {
+function getPartitionForRequest(req: Request): ModelFamily {
   // There is a single request queue, but it is partitioned by model and API
   // provider.
   // - claude: requests for the Anthropic API, regardless of model
@@ -138,19 +142,19 @@ function getPartitionForRequest(req: Request): QueuePartition {
   const provider = req.outboundApi;
   const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo";
   if (provider === "anthropic") {
-    return "claude";
+    return getClaudeModelFamily(model);
   }
-  if (provider === "openai" && model.startsWith("gpt-4")) {
-    return "gpt-4";
+  if (provider === "openai") {
+    return getOpenAIModelFamily(model);
   }
   return "turbo";
 }
-function getQueueForPartition(partition: QueuePartition): Request[] {
+function getQueueForPartition(partition: ModelFamily): Request[] {
   return queue.filter((req) => getPartitionForRequest(req) === partition);
 }
-export function dequeue(partition: QueuePartition): Request | undefined {
+export function dequeue(partition: ModelFamily): Request | undefined {
   const modelQueue = getQueueForPartition(partition);
   if (modelQueue.length === 0) {
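With the new partitioning, a 32k-context OpenAI request gets its own queue partition instead of sharing the old "gpt-4" one. A hypothetical illustration (not part of this commit), assuming the mapping sketched above:

// outboundApi "openai"    + model "gpt-4-32k"     -> "gpt4-32k"
// outboundApi "openai"    + model "gpt-4"         -> "gpt4"
// outboundApi "openai"    + model "gpt-3.5-turbo" -> "turbo"
// outboundApi "anthropic" + any model             -> a Claude family
const next = dequeue("gpt4-32k"); // drains only requests partitioned as gpt4-32k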
@@ -189,13 +193,19 @@ function processQueue() {
   // This isn't completely correct, because a key can service multiple models.
   // Currently if a key is locked out on one model it will also stop servicing
   // the others, because we only track one rate limit per key.
+  // TODO: `getLockoutPeriod` uses model names instead of model families
+  const gpt432kLockout = keyPool.getLockoutPeriod("gpt-4-32k");
   const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4");
   const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo");
   const claudeLockout = keyPool.getLockoutPeriod("claude-v1");
   const reqs: (Request | undefined)[] = [];
+  if (gpt432kLockout === 0) {
+    reqs.push(dequeue("gpt4-32k"));
+  }
   if (gpt4Lockout === 0) {
-    reqs.push(dequeue("gpt-4"));
+    reqs.push(dequeue("gpt4"));
   }
   if (turboLockout === 0) {
     reqs.push(dequeue("turbo"));
@@ -244,7 +254,7 @@ export function start() {
   log.info(`Started request queue.`);
 }
-let waitTimes: { partition: QueuePartition; start: number; end: number }[] = [];
+let waitTimes: { partition: ModelFamily; start: number; end: number }[] = [];
 /** Adds a successful request to the list of wait times. */
 export function trackWaitTime(req: Request) {
@@ -256,7 +266,7 @@ export function trackWaitTime(req: Request) {
 }
 /** Returns average wait time in milliseconds. */
-export function getEstimatedWaitTime(partition: QueuePartition) {
+export function getEstimatedWaitTime(partition: ModelFamily) {
   const now = Date.now();
   const recentWaits = waitTimes.filter(
     (wt) => wt.partition === partition && now - wt.end < 300 * 1000
@@ -271,7 +281,7 @@ export function getEstimatedWaitTime(partition: QueuePartition) {
   );
 }
-export function getQueueLength(partition: QueuePartition | "all" = "all") {
+export function getQueueLength(partition: ModelFamily | "all" = "all") {
   if (partition === "all") {
     return queue.length;
   }
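A hypothetical caller (not shown in this diff) could combine the two per-family accessors to report queue state, for example:

// Illustrative only; the function and field names are assumptions, not part of this commit.
function getQueueStatus(family: ModelFamily) {
  return {
    queued: getQueueLength(family),
    estimatedWaitMs: getEstimatedWaitTime(family),
  };
}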